Files
BiPy/main.cpp
T
2026-05-15 18:45:50 +07:00

984 lines
41 KiB
C++

// ============================================================================
// Xenith Studio - Node Editor v3.1 (fixed training + loading from file + port fixes)
// ============================================================================
#define IMGUI_DEFINE_MATH_OPERATORS
#include "imgui.h"
#include "imgui_internal.h"
#include "imgui_impl_glfw.h"
#include "imgui_impl_opengl3.h"
#include <GLFW/glfw3.h>
#include <algorithm>
#include <atomic>
#include <chrono>
#include <cmath>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>
#include "Xenith/core.hpp"
#include "Xenith/token/token.hpp"
// ============================================================================
// NODE EDITOR ENGINE
// ============================================================================
namespace NodeEditor {
/// Side of a node a port belongs to: inputs on the left edge, outputs on the right.
enum class PortType {
    Input,
    Output
};

/// Kind of node in the graph; selects the port layout built by Node::UpdatePorts().
enum class NodeType {
    Input,
    Hidden,
    Output,
    Splitter,
    Merger
};

/// A named connection point on a node.
struct Port {
    std::string name;   // label drawn next to the port circle
    PortType type;      // which side of the node the port is on
    int index;          // position of the port on its side (0 unless given)

    Port(const std::string& portName, PortType portType, int portIndex = 0)
        : name(portName), type(portType), index(portIndex) {}
};
struct Node {
int id;
std::string title;
NodeType type;
ImVec2 pos;
ImVec2 size;
bool selected;
bool dragging;
ImVec2 dragOffset;
int layerSize;
int branchCount;
int activeBranch;
std::vector<Port> inputs;
std::vector<Port> outputs;
Node(int id_, const std::string& t_, NodeType type_)
: id(id_), title(t_), type(type_), pos(0,0), size(180,100),
selected(false), dragging(false), layerSize(128), branchCount(2), activeBranch(-1) {
UpdatePorts();
}
void UpdatePorts() {
inputs.clear(); outputs.clear();
switch(type) {
case NodeType::Input:
// ИСПРАВЛЕНО: 0 входов, 1 выход
outputs.push_back(Port("Out", PortType::Output));
break;
case NodeType::Hidden:
// ИСПРАВЛЕНО: 1 вход, 1 выход
inputs.push_back(Port("In", PortType::Input));
outputs.push_back(Port("Out", PortType::Output));
break;
case NodeType::Output:
// ИСПРАВЛЕНО: 1 вход, 0 выходов
inputs.push_back(Port("In", PortType::Input));
break;
case NodeType::Splitter:
inputs.push_back(Port("In", PortType::Input));
for(int i=0;i<branchCount;i++) outputs.push_back(Port("Br "+std::to_string(i), PortType::Output));
break;
case NodeType::Merger:
for(int i=0;i<branchCount;i++) inputs.push_back(Port("Br "+std::to_string(i), PortType::Input));
outputs.push_back(Port("Out", PortType::Output));
break;
}
}
ImVec2 GetPortScreenPos(int portIdx, bool isInput, const ImVec2& canvasPos, const ImVec2& pan) const {
float y = pos.y + 30 + portIdx * 22;
return canvasPos + (isInput ? ImVec2(pos.x - 6, y) : ImVec2(pos.x + size.x + 6, y)) + pan;
}
};
/// Directed edge from an output port of one node to an input port of another.
/// The node fields hold Node::id values. The port fields are indices into the
/// source node's `outputs` and the target node's `inputs`.
struct Connection {
    int fromNode;
    int fromPort;
    int toNode;
    int toPort;

    Connection(int srcNode, int srcPort, int dstNode, int dstPort)
        : fromNode(srcNode), fromPort(srcPort), toNode(dstNode), toPort(dstPort) {}
};
/// Complete editor state: the graph itself plus the per-frame interaction state
/// (hovered port, in-progress connection, panning) maintained by the input handler.
struct GraphState {
    // --- graph contents ---
    std::vector<Node> nodes;
    std::vector<Connection> connections;
    int nextId{0};                              // id given to the next created node
    int selectedNode{-1};                       // Node::id of the selection, -1 if none

    // --- port under the cursor (-1 when none) ---
    int hoveredPortNode{-1};
    int hoveredPortIdx{-1};
    PortType hoveredPortType{PortType::Input};

    // --- connection currently being dragged ---
    bool creatingConn{false};
    int connStartNode{-1};
    int connStartPort{-1};
    PortType connStartType{PortType::Output};
    ImVec2 connMousePos;                        // rubber-band end point, screen coordinates

    // --- view panning ---
    ImVec2 pan{100.0f, 80.0f};                  // offset added to graph coordinates
    bool panning{false};
    ImVec2 panStart;
};
/// Outline color for a node: yellow when selected, otherwise a per-type color
/// (gray for any type not listed).
ImU32 GetNodeColor(NodeType type, bool selected) {
    ImU32 color = IM_COL32(120,120,120,255);
    if(selected) {
        color = IM_COL32(255,255,100,255);
    } else if(type == NodeType::Input) {
        color = IM_COL32(80,200,120,255);
    } else if(type == NodeType::Hidden) {
        color = IM_COL32(80,140,220,255);
    } else if(type == NodeType::Output) {
        color = IM_COL32(220,80,80,255);
    } else if(type == NodeType::Splitter) {
        color = IM_COL32(220,180,50,255);
    } else if(type == NodeType::Merger) {
        color = IM_COL32(180,80,220,255);
    }
    return color;
}
/// Draws one node: background, type-colored outline, title, layer size, the
/// branch count for Splitter/Merger nodes, and all ports with their labels.
/// A port matching the graph's hovered port is highlighted.
void DrawNode(ImDrawList* dl, const Node& n, const GraphState& g, const ImVec2& canvasPos, const ImVec2& pan) {
    const ImU32 portIdle = IM_COL32(180,180,180,255);
    const ImU32 portHot = IM_COL32(255,255,100,255);
    const ImU32 labelColor = IM_COL32(200,200,200,255);
    const ImVec2 topLeft = canvasPos + n.pos + pan;
    const ImVec2 bottomRight = topLeft + n.size;

    dl->AddRectFilled(topLeft, bottomRight, IM_COL32(35,38,48,255), 6);
    dl->AddRect(topLeft, bottomRight, GetNodeColor(n.type, n.selected), 6, 0, 2.0f);
    dl->AddText(topLeft + ImVec2(10,4), IM_COL32(255,255,255,255), n.title.c_str());
    const std::string sizeLabel = std::to_string(n.layerSize);
    dl->AddText(topLeft + ImVec2(10,20), IM_COL32(160,160,160,200), sizeLabel.c_str());

    const bool hasBranches = n.type == NodeType::Splitter || n.type == NodeType::Merger;
    if(hasBranches) {
        const std::string branchLabel = " x" + std::to_string(n.branchCount);
        dl->AddText(topLeft + ImVec2(10, n.size.y-18), IM_COL32(255,220,100,255), branchLabel.c_str());
    }

    // Filled circle plus dark ring; labels are positioned by the caller.
    auto drawPortCircle = [&](const ImVec2& center, bool hot) {
        dl->AddCircleFilled(center, 5, hot ? portHot : portIdle);
        dl->AddCircle(center, 5, IM_COL32(50,50,50,255), 12, 1.5f);
    };
    auto isHot = [&](int idx, PortType side) {
        return g.hoveredPortNode == n.id && g.hoveredPortIdx == idx && g.hoveredPortType == side;
    };

    // Input labels go to the right of the port.
    for(int i = 0; i < (int)n.inputs.size(); ++i) {
        const ImVec2 center = n.GetPortScreenPos(i, true, canvasPos, pan);
        drawPortCircle(center, isHot(i, PortType::Input));
        dl->AddText(center + ImVec2(10,-5), labelColor, n.inputs[i].name.c_str());
    }
    // Output labels are right-aligned to the left of the port.
    for(int i = 0; i < (int)n.outputs.size(); ++i) {
        const ImVec2 center = n.GetPortScreenPos(i, false, canvasPos, pan);
        drawPortCircle(center, isHot(i, PortType::Output));
        const char* label = n.outputs[i].name.c_str();
        const float labelWidth = ImGui::CalcTextSize(label).x;
        dl->AddText(center - ImVec2(10 + labelWidth, 5), labelColor, label);
    }
}
/// Draws every connection as a cubic Bezier curve from the source node's output
/// port to the target node's input port. While a connection is being dragged,
/// it also draws the rubber-band curve to the cursor. Connections whose nodes no
/// longer exist are skipped.
void DrawConnections(ImDrawList* dl, const GraphState& g, const ImVec2& canvasPos, const ImVec2& pan) {
    // Node lookup by id. Capturing `g` by reference avoids copying the whole
    // graph (nodes, ports, connections) every frame, as the old `[g]` did.
    auto findNode = [&g](int id) -> const Node* {
        auto it = std::find_if(g.nodes.begin(), g.nodes.end(), [id](const Node& n){ return n.id == id; });
        return it == g.nodes.end() ? nullptr : &*it;
    };
    const float tangent = 60.0f;

    for(const auto& c : g.connections) {
        const Node* from = findNode(c.fromNode);
        const Node* to = findNode(c.toNode);
        if(!from || !to) continue;
        const ImVec2 start = from->GetPortScreenPos(c.fromPort, false, canvasPos, pan);
        const ImVec2 end = to->GetPortScreenPos(c.toPort, true, canvasPos, pan);
        dl->AddBezierCubic(start, start + ImVec2(tangent,0), end - ImVec2(tangent,0), end,
                           IM_COL32(160,160,160,150), 2.0f, 24);
    }

    if(g.creatingConn) {
        if(const Node* src = findNode(g.connStartNode)) {
            const bool fromOutput = g.connStartType == PortType::Output;
            const ImVec2 start = src->GetPortScreenPos(g.connStartPort, !fromOutput, canvasPos, pan);
            const ImVec2 end = g.connMousePos;
            // Outputs leave to the right and inputs to the left. Orient the
            // tangents so a curve dragged from an input port doesn't loop back
            // through its own node.
            const float dir = fromOutput ? 1.0f : -1.0f;
            dl->AddBezierCubic(start, start + ImVec2(dir * tangent, 0), end - ImVec2(dir * tangent, 0), end,
                               IM_COL32(255,255,100,180), 2.5f, 24);
        }
    }
}
/// Handles mouse input for the graph canvas. It covers:
///  - panning (middle-drag; the wheel scrolls vertically),
///  - port hover detection,
///  - creating a connection by dragging between an output and an input port,
///  - node selection and dragging,
///  - deselection when clicking on empty canvas.
/// DrawGraph calls it only while the canvas item is hovered.
/// @param canvasSize  unused; kept for interface compatibility.
void HandleInput(GraphState& g, const ImVec2& canvasPos, const ImVec2& canvasSize) {
    (void)canvasSize;
    ImGuiIO& io = ImGui::GetIO();
    const ImVec2 mouseScreen = io.MousePos;
    const ImVec2 mouseWorld = mouseScreen - canvasPos - g.pan;
    const bool leftClicked = ImGui::IsMouseClicked(ImGuiMouseButton_Left);
    const bool leftReleased = ImGui::IsMouseReleased(ImGuiMouseButton_Left);
    // This function doesn't run while the cursor is outside the canvas, so a
    // release there is never seen as IsMouseReleased(). "Left button no longer
    // down" is used to end drags, so nodes and rubber bands don't stick to the
    // cursor when it comes back.
    const bool leftDown = ImGui::IsMouseDown(ImGuiMouseButton_Left);

    // --- panning ---
    if(ImGui::IsMouseDown(ImGuiMouseButton_Middle)) {
        if(!g.panning) { g.panning = true; g.panStart = mouseScreen - g.pan; }
        g.pan = mouseScreen - g.panStart;
    } else {
        g.panning = false;
    }
    if(ImGui::IsWindowHovered() && io.MouseWheel != 0.0f) {
        g.pan.y -= io.MouseWheel * 20.0f;
    }

    // --- port hover: within 8 px of a port center; later nodes win ---
    const float hoverRadiusSq = 64.0f;
    g.hoveredPortNode = -1;
    g.hoveredPortIdx = -1;
    for(const auto& n : g.nodes) {
        for(int i = 0; i < (int)n.inputs.size(); i++) {
            if(ImLengthSqr(mouseScreen - n.GetPortScreenPos(i, true, canvasPos, g.pan)) < hoverRadiusSq) {
                g.hoveredPortNode = n.id; g.hoveredPortIdx = i; g.hoveredPortType = PortType::Input;
            }
        }
        for(int i = 0; i < (int)n.outputs.size(); i++) {
            if(ImLengthSqr(mouseScreen - n.GetPortScreenPos(i, false, canvasPos, g.pan)) < hoverRadiusSq) {
                g.hoveredPortNode = n.id; g.hoveredPortIdx = i; g.hoveredPortType = PortType::Output;
            }
        }
    }

    // --- connection creation ---
    if(leftClicked && g.hoveredPortNode != -1 && !g.panning) {
        g.creatingConn = true;
        g.connStartNode = g.hoveredPortNode;
        g.connStartPort = g.hoveredPortIdx;
        g.connStartType = g.hoveredPortType;
    }
    if(g.creatingConn) {
        g.connMousePos = mouseScreen;
        if(!leftDown) {
            // Connect only when the release is seen over a compatible port on another
            // node. A release missed outside the canvas just cancels.
            const bool validTarget = leftReleased && g.hoveredPortNode != -1 &&
                                     g.hoveredPortNode != g.connStartNode &&
                                     g.hoveredPortType != g.connStartType;
            if(validTarget) {
                const Connection conn = (g.connStartType == PortType::Output)
                    ? Connection(g.connStartNode, g.connStartPort, g.hoveredPortNode, g.hoveredPortIdx)
                    : Connection(g.hoveredPortNode, g.hoveredPortIdx, g.connStartNode, g.connStartPort);
                // An identical edge would duplicate the source in SyncToConfigs.
                const bool exists = std::any_of(g.connections.begin(), g.connections.end(), [&conn](const Connection& c) {
                    return c.fromNode == conn.fromNode && c.fromPort == conn.fromPort &&
                           c.toNode == conn.toNode && c.toPort == conn.toPort;
                });
                if(!exists) g.connections.push_back(conn);
            }
            g.creatingConn = false;
        }
        if(ImGui::IsMouseClicked(ImGuiMouseButton_Right)) g.creatingConn = false;
    }

    // --- selection: only the topmost node under the cursor is selected ---
    if(leftClicked && !g.creatingConn) {
        Node* hit = nullptr;
        for(auto& n : g.nodes) {  // drawn in vector order, so the last hit is on top
            const ImVec2 scr = canvasPos + n.pos + g.pan;
            if(ImRect(scr, scr + n.size).Contains(mouseScreen)) hit = &n;
        }
        if(hit) {
            for(auto& n : g.nodes) { n.selected = false; n.dragging = false; }
            hit->selected = true;
            hit->dragging = true;
            hit->dragOffset = mouseWorld - hit->pos;
            g.selectedNode = hit->id;
        } else if(g.hoveredPortNode == -1 && !g.panning) {
            g.selectedNode = -1;
            for(auto& n : g.nodes) n.selected = false;
        }
    }

    // --- dragging ---
    for(auto& n : g.nodes) {
        if(!n.dragging) continue;
        if(leftDown || leftReleased) n.pos = mouseWorld - n.dragOffset;
        if(!leftDown) n.dragging = false;
    }
}
/// Draws the node-graph canvas (background grid, connections, nodes) into the
/// current window, covering `canvasSize` pixels from the cursor position.
/// It also:
///  - forwards mouse input to HandleInput while the canvas is hovered,
///  - shows the right-click "add node" menu,
///  - shows a properties panel for the selected node (neurons, branches,
///    input branch, delete).
void DrawGraph(GraphState& g, const ImVec2& canvasSize) {
    ImDrawList* dl = ImGui::GetWindowDrawList();
    const ImVec2 canvasPos = ImGui::GetCursorScreenPos();
    // InvisibleButton asserts on a zero-sized item (e.g. a collapsed child window).
    const ImVec2 area(std::max(canvasSize.x, 1.0f), std::max(canvasSize.y, 1.0f));

    // Background and a grid that scrolls with the pan offset.
    dl->AddRectFilled(canvasPos, canvasPos+area, IM_COL32(25,27,35,255));
    const float gs = 40.0f;
    const ImVec2 off = ImVec2(fmodf(g.pan.x, gs), fmodf(g.pan.y, gs));
    for(float x=off.x; x<area.x; x+=gs) dl->AddLine(canvasPos+ImVec2(x,0), canvasPos+ImVec2(x,area.y), IM_COL32(40,44,55,120));
    for(float y=off.y; y<area.y; y+=gs) dl->AddLine(canvasPos+ImVec2(0,y), canvasPos+ImVec2(area.x,y), IM_COL32(40,44,55,120));

    DrawConnections(dl, g, canvasPos, g.pan);
    for(const auto& n : g.nodes) DrawNode(dl, n, g, canvasPos, g.pan);

    // The invisible button receives hover and clicks for the canvas and anchors
    // the context menu below.
    ImGui::InvisibleButton("##Canvas", area);
    if(ImGui::IsItemHovered()) HandleInput(g, canvasPos, area);

    if(ImGui::BeginPopupContextItem("##Ctx")) {
        const ImVec2 wPos = ImGui::GetMousePos() - canvasPos - g.pan;
        if(ImGui::MenuItem("🟢 Input")) { g.nodes.emplace_back(g.nextId++, "Input", NodeType::Input); g.nodes.back().pos=wPos; }
        if(ImGui::MenuItem(" Hidden")) { g.nodes.emplace_back(g.nextId++, "Hidden", NodeType::Hidden); g.nodes.back().pos=wPos; }
        if(ImGui::MenuItem(" Output")) { g.nodes.emplace_back(g.nextId++, "Output", NodeType::Output); g.nodes.back().pos=wPos; }
        ImGui::Separator();
        if(ImGui::MenuItem(" Splitter (x2)")) { auto& n=g.nodes.emplace_back(g.nextId++, "Splitter", NodeType::Splitter); n.pos=wPos; n.branchCount=2; n.UpdatePorts(); }
        if(ImGui::MenuItem(" Merger (x2)")) { auto& n=g.nodes.emplace_back(g.nextId++, "Merger", NodeType::Merger); n.pos=wPos; n.branchCount=2; n.UpdatePorts(); }
        ImGui::Separator(); ImGui::Text("LMB: Drag/Connect | RMB: Cancel | MMB: Pan");
        ImGui::EndPopup();
    }

    if(g.selectedNode == -1) return;
    const int selId = g.selectedNode;
    auto it = std::find_if(g.nodes.begin(), g.nodes.end(), [selId](const Node& n){ return n.id == selId; });
    if(it == g.nodes.end()) {
        g.selectedNode = -1;  // selection refers to a node that no longer exists
        return;
    }

    Node& sel = *it;
    bool deleteRequested = false;
    ImGui::SetNextWindowPos(canvasPos + ImVec2(10,10));
    ImGui::SetNextWindowSize(ImVec2(240,220));
    if(ImGui::Begin("##Props", nullptr, ImGuiWindowFlags_NoTitleBar|ImGuiWindowFlags_AlwaysAutoResize|ImGuiWindowFlags_NoMove)) {
        ImGui::TextColored(ImVec4(1,1,0.5f,1), "%s # %d", sel.title.c_str(), sel.id); ImGui::Separator();
        if(sel.type != NodeType::Output) { ImGui::InputInt("Neurons", &sel.layerSize); if(sel.layerSize<1) sel.layerSize=1; }
        if(sel.type == NodeType::Splitter || sel.type == NodeType::Merger) {
            ImGui::SliderInt("Branches", &sel.branchCount, 2, 8);
            if(ImGui::IsItemDeactivatedAfterEdit()) {
                sel.UpdatePorts();
                // Fewer branches means fewer ports. Drop connections that still
                // reference the removed port indices.
                const int id = sel.id;
                const int inCount = (int)sel.inputs.size();
                const int outCount = (int)sel.outputs.size();
                g.connections.erase(std::remove_if(g.connections.begin(), g.connections.end(),
                    [id, inCount, outCount](const Connection& c) {
                        return (c.toNode == id && c.toPort >= inCount) || (c.fromNode == id && c.fromPort >= outCount);
                    }), g.connections.end());
            }
        } else if(sel.type == NodeType::Input) {
            ImGui::Text("Branch:");
            ImGui::RadioButton("Combined", &sel.activeBranch, -1);
            ImGui::RadioButton("A (0)", &sel.activeBranch, 0);
            ImGui::RadioButton("B (1)", &sel.activeBranch, 1);
        }
        ImGui::Spacing();
        if(ImGui::Button("🗑 Delete", ImVec2(-1,28))) deleteRequested = true;
    }
    // Dear ImGui requires End() after every Begin(), whatever Begin() returned.
    ImGui::End();

    if(deleteRequested) {
        const int id = sel.id;
        g.connections.erase(std::remove_if(g.connections.begin(), g.connections.end(),
            [id](const Connection& c){ return c.fromNode == id || c.toNode == id; }), g.connections.end());
        if(g.creatingConn && g.connStartNode == id) g.creatingConn = false;
        g.nodes.erase(it);
        g.selectedNode = -1;
    }
}
/// Rebuilds the editor graph from layer configurations.
/// Layer i becomes the node with id i. Layers without sources become Input
/// nodes, the last layer becomes the Output node, and every other layer becomes
/// a Hidden node. Each entry of `sources` becomes a connection into the layer's
/// single input port (port 0).
/// Source indices outside [0, cfgs.size()) are skipped. This can happen with a
/// hand-edited architecture file; such entries would otherwise create
/// connections to nonexistent nodes.
/// Selection and connection-in-progress state is reset because the old node
/// ids are reused.
void SyncFromConfigs(GraphState& g, const std::vector<LayerStructure_t>& cfgs) {
    g.nodes.clear();
    g.connections.clear();
    g.nextId = 0;
    g.selectedNode = -1;
    g.creatingConn = false;

    const int count = (int)cfgs.size();
    g.nodes.reserve(cfgs.size());
    for(int i = 0; i < count; i++) {
        const auto& c = cfgs[i];
        const NodeType t = c.sources.empty() ? NodeType::Input
                         : (i == count - 1 ? NodeType::Output : NodeType::Hidden);
        const char* title = t == NodeType::Input ? "Input" : t == NodeType::Output ? "Output" : "Hidden";
        Node n(g.nextId++, title, t);  // the constructor builds the ports
        n.pos = ImVec2(120.0f + i * 260.0f, 150.0f + (i % 3) * 120.0f);
        n.layerSize = c.size;
        n.activeBranch = c.branch;
        g.nodes.push_back(std::move(n));
    }

    for(int i = 0; i < count; i++) {
        for(int src : cfgs[i].sources) {
            if(src >= 0 && src < count) g.connections.emplace_back(src, 0, i, 0);
        }
    }
}
/// Converts the editor graph into layer configurations for NeuralNetwork.
///
/// Splitter and Merger nodes are not layers and are skipped, so connections
/// routed through them are dropped. NOTE(review): confirm this is intended.
///
/// Layers are ordered left to right by node x position. A stable sort keeps
/// creation order for equal x, so the result is deterministic. Output nodes are
/// always placed last, because the rest of the program (SyncFromConfigs,
/// GenerateArchitectureText, TrainTask) treats the last layer as the network
/// output.
///
/// Each incoming connection from a kept node appends that node's layer index to
/// `sources`, and appends the target node's activeBranch to `sourceBranches`.
void SyncToConfigs(GraphState& g, std::vector<LayerStructure_t>& cfgs) {
    cfgs.clear();
    std::vector<const Node*> sorted;
    for(const auto& n : g.nodes) {
        if(n.type != NodeType::Splitter && n.type != NodeType::Merger) sorted.push_back(&n);
    }
    std::stable_sort(sorted.begin(), sorted.end(), [](const Node* a, const Node* b) {
        const bool aOut = a->type == NodeType::Output;
        const bool bOut = b->type == NodeType::Output;
        if(aOut != bOut) return bOut;  // non-output layers come before output layers
        return a->pos.x < b->pos.x;
    });

    std::map<int,int> idToIdx;
    for(size_t i = 0; i < sorted.size(); i++) idToIdx[sorted[i]->id] = (int)i;

    cfgs.reserve(sorted.size());
    for(const Node* n : sorted) {
        LayerStructure_t l;
        l.size = n->layerSize;
        l.branch = n->activeBranch;
        for(const auto& c : g.connections) {
            if(c.toNode != n->id) continue;
            const auto src = idToIdx.find(c.fromNode);
            if(src == idToIdx.end()) continue;  // source was a Splitter/Merger or is gone
            l.sources.push_back(src->second);
            l.sourceBranches.push_back(n->activeBranch);
        }
        cfgs.push_back(l);
    }
}
std::string GenerateArchitectureText(const GraphState& g, const std::vector<LayerStructure_t>& configs) {
std::stringstream ss;
ss << "=== NEURAL NETWORK ARCHITECTURE ===\n\n";
long long totalParams = 0;
for(size_t i=0;i<configs.size();i++) {
if(configs[i].sources.empty()) continue;
int inputSize = 0;
for(int src : configs[i].sources) inputSize += configs[src].size;
long long layerParams = (long long)inputSize * configs[i].size + configs[i].size;
totalParams += layerParams;
}
ss << "Total Parameters: " << totalParams << " (" << std::fixed << std::setprecision(2) << (totalParams/1000000.0) << "M)\n\n";
ss << "Layer Structure:\n";
ss << "----------------\n";
for(size_t i=0;i<configs.size();i++) {
auto& l = configs[i];
std::string type = l.sources.empty() ? "INPUT" : (i==configs.size()-1 ? "OUTPUT" : "HIDDEN");
ss << "Layer " << i << " [" << type << "]: " << l.size << " neurons";
if(l.branch != -1) ss << " (Branch " << (char)('A'+l.branch) << ")";
ss << "\n";
if(!l.sources.empty()) {
ss << " ← From: ";
for(size_t j=0;j<l.sources.size();j++) {
ss << "Layer " << l.sources[j];
if(j < l.sources.size()-1) ss << ", ";
}
ss << "\n";
}
}
ss << "\n=== NODE GRAPH ===\n";
ss << "Total Nodes: " << g.nodes.size() << "\n";
ss << "Total Connections: " << g.connections.size() << "\n\n";
for(const auto& n : g.nodes) {
ss << "Node #" << n.id << " [" << n.title << "] @ ("
<< (int)n.pos.x << "," << (int)n.pos.y << ")\n";
ss << " Size: " << n.layerSize << " neurons\n";
ss << " Inputs: " << n.inputs.size() << " | Outputs: " << n.outputs.size() << "\n";
if(n.type == NodeType::Splitter || n.type == NodeType::Merger)
ss << " Branches: " << n.branchCount << "\n";
}
return ss.str();
}
// === SAVE/LOAD ARCHITECTURE AS TEXT STRUCTURE ===
/// Writes `configs` to `filename` in the text format read by
/// LoadArchitectureFromFile:
///   [ARCHITECTURE]
///   layers=N
///   [LAYER i]            (one block per layer, followed by a blank line)
///   size=, branch=, sources=a,b,..., branches=a,b,...
/// @return true only if the file was opened and every write (including the
///         final flush on close) succeeded.
bool SaveArchitectureToFile(const std::vector<LayerStructure_t>& configs, const std::string& filename) {
    std::ofstream file(filename);
    if(!file.is_open()) return false;

    // Writes the elements of `values` separated by commas (no trailing comma).
    auto writeList = [&file](const auto& values) {
        for(size_t j = 0; j < values.size(); j++) {
            if(j) file << ",";
            file << values[j];
        }
    };

    file << "[ARCHITECTURE]\n";
    file << "layers=" << configs.size() << "\n";
    for(size_t i = 0; i < configs.size(); i++) {
        const auto& l = configs[i];
        file << "[LAYER " << i << "]\n";
        file << "size=" << l.size << "\n";
        file << "branch=" << l.branch << "\n";
        file << "sources=";
        writeList(l.sources);
        file << "\n";
        file << "branches=";
        writeList(l.sourceBranches);
        file << "\n\n";
    }
    file.close();
    // close() flushes the buffer; a full disk or I/O error shows up as failbit.
    return !file.fail();
}
bool LoadArchitectureFromFile(std::vector<LayerStructure_t>& configs, const std::string& filename) {
std::ifstream file(filename);
if(!file.is_open()) return false;
configs.clear();
std::string line;
int currentLayer = -1;
while(std::getline(file, line)) {
if(line.find("[LAYER ") == 0) {
size_t start = line.find(' ') + 1;
size_t end = line.find(']');
currentLayer = std::stoi(line.substr(start, end - start));
configs.push_back(LayerStructure_t());
} else if(line.find("size=") == 0 && currentLayer >= 0) {
configs[currentLayer].size = std::stoi(line.substr(5));
} else if(line.find("branch=") == 0 && currentLayer >= 0) {
configs[currentLayer].branch = std::stoi(line.substr(7));
} else if(line.find("sources=") == 0 && currentLayer >= 0) {
std::string srcs = line.substr(8);
std::stringstream ss(srcs);
std::string token;
while(std::getline(ss, token, ',')) {
if(!token.empty()) configs[currentLayer].sources.push_back(std::stoi(token));
}
} else if(line.find("branches=") == 0 && currentLayer >= 0) {
std::string brs = line.substr(9);
std::stringstream ss(brs);
std::string token;
while(std::getline(ss, token, ',')) {
if(!token.empty()) configs[currentLayer].sourceBranches.push_back(std::stoi(token));
}
}
}
file.close();
return !configs.empty();
}
} // namespace NodeEditor
// ============================================================================
// CHAT SYSTEM
// ============================================================================
/// One line of a conversation; the chat UI stores "USER" or "AI" in `role`.
struct ChatMessage {
    std::string role;
    std::string content;
};

/// A named conversation listed in the chat column.
struct ChatSession {
    std::string name;                   // label shown in the chat list
    std::vector<ChatMessage> messages;  // history, in the order appended
    int id;                             // numeric id given at creation

    ChatSession(int sessionId, const std::string& sessionName)
        : name(sessionName), messages(), id(sessionId) {}
};
// ============================================================================
// DATASET
// ============================================================================
/// One dataset example: a user prompt and the reply the model should learn.
struct TrainingSample {
    std::string userText;   // text found between [USER] and [AI]
    std::string aiText;     // text found between [AI] and [EOS]
};
/// Extracts training samples from text made of repeated
///   [USER]question[AI]answer[EOS]
/// blocks. Spaces, tabs, CR and LF around each part are trimmed on both
/// sides, so the markers may sit on their own lines. Previously only trailing
/// whitespace was removed, so a newline after [USER] or [AI] became part of the
/// text.
/// Samples whose question or answer is empty after trimming are skipped.
/// Parsing stops at the first block that lacks an [AI] or [EOS] marker.
/// @param tok  unused; kept for interface compatibility.
std::vector<TrainingSample> ParseDataset(const std::string& text, Tokenizer& tok) {
    (void)tok;
    const std::string userMarker = "[USER]";
    const std::string aiMarker = "[AI]";
    const std::string eosMarker = "[EOS]";
    // Returns `s` without leading/trailing whitespace ("" if all whitespace).
    auto trim = [](const std::string& s) {
        const char* ws = " \t\r\n";
        const size_t first = s.find_first_not_of(ws);
        if(first == std::string::npos) return std::string();
        const size_t last = s.find_last_not_of(ws);
        return s.substr(first, last - first + 1);
    };

    std::vector<TrainingSample> samples;
    size_t pos = 0;
    while(pos < text.length()) {
        const size_t userPos = text.find(userMarker, pos);
        if(userPos == std::string::npos) break;
        const size_t userStart = userPos + userMarker.size();
        const size_t aiPos = text.find(aiMarker, userStart);
        if(aiPos == std::string::npos) break;
        const size_t aiStart = aiPos + aiMarker.size();
        const size_t eosPos = text.find(eosMarker, aiStart);
        if(eosPos == std::string::npos) break;

        std::string userText = trim(text.substr(userStart, aiPos - userStart));
        std::string aiText = trim(text.substr(aiStart, eosPos - aiStart));
        if(!userText.empty() && !aiText.empty()) {
            samples.push_back({std::move(userText), std::move(aiText)});
        }
        pos = eosPos + eosMarker.size();
    }
    return samples;
}
/// Reads the entire contents of `filename` into `buffer`.
/// @return false, leaving `buffer` untouched, if the file cannot be opened;
///         true otherwise.
bool LoadDatasetFromFile(std::string& buffer, const std::string& filename) {
    std::ifstream in(filename);
    if(!in.is_open()) {
        return false;
    }
    std::ostringstream contents;
    contents << in.rdbuf();
    buffer = contents.str();
    return true;
}
// ============================================================================
// GLOBAL STATE
// ============================================================================
// Model dimensions, editable from the Training tab.
int UI_CONTEXT = 256;          // number of context tokens fed to the network
int UI_EMBED = 8;              // embedding width per token
int UI_VOCAB = 300;            // vocabulary size; generation stops on ids >= this
// Maximum number of tokens generated for one chat reply.
int MAX_RESPONSE_TOKENS = 50;
/// Application state shared by the UI loop and the training thread (TrainTask).
/// `mtx` is taken around accesses to `lastStatus` and the speed/token metrics
/// shown in the Training tab. `training` and `stop` are atomic flags.
struct UIState {
    // --- chats ---
    std::vector<ChatSession> chats;
    int activeChat{0};                          // index into `chats`
    char inputBuf[512]{};                       // chat input line
    bool scrollChat{true};                      // scroll messages to the bottom next frame

    // --- training status and settings ---
    TrainStatus lastStatus;
    std::mutex mtx;
    float lr{0.01f};
    int epochs{10};
    int doneEpochs{0};
    double tokensPerSecond{0.0};
    double generationTokensPerSecond{0.0};
    int totalTokensTrained{0};
    std::chrono::steady_clock::time_point trainingStartTime;
    std::chrono::steady_clock::time_point lastTokenTime;    // NOTE(review): not read in this file

    // --- dataset ---
    char dsBuf[524288]{};                       // editable dataset text (512 KiB)
    std::vector<TrainingSample> parsedSamples;
    bool datasetLoaded{false};

    std::atomic<bool> training{false};
    std::atomic<bool> stop{false};

    // --- architecture ---
    std::vector<LayerStructure_t> layers;
    NodeEditor::GraphState graph;
    std::string architectureText;
    bool showArchitecture{true};                // NOTE(review): not read in this file
    char archFilePath[256]{"architecture.txt"};
    char datasetFilePath[256]{"dataset.txt"};
} ui;
// ============================================================================
// NEURAL NETWORK HELPERS
// ============================================================================
/// Builds the network input map from a token history.
///
/// For every layer without sources (an input layer), the most recent tokens of
/// `toks` are embedded with `emb` and concatenated:
///  - branch -1: up to UI_CONTEXT tokens, starting at max(0, size - UI_CONTEXT);
///  - branch 1: up to UI_CONTEXT/2 tokens, starting UI_CONTEXT/2 later;
///  - any other branch: up to UI_CONTEXT/2 tokens from the window start.
/// Each unused token slot is padded with UI_EMBED zeros at the end.
///
/// The vector is then resized (zero-padded or truncated) to the layer's declared
/// size. The default input layer's size is UI_CONTEXT*UI_EMBED, matching the
/// embedding length. Without the resize, editing Context Size / Embedding Dim in
/// the UI, or adding an Input node with another neuron count, would give the
/// network an input of the wrong length.
/// @return map from input-layer index to its input vector.
std::map<int, std::vector<double>> PrepInput(const std::vector<int>& toks, Embedder& emb) {
    std::map<int, std::vector<double>> res;
    const int total = (int)toks.size();
    for(size_t i = 0; i < ui.layers.size(); i++) {
        const auto& layer = ui.layers[i];
        if(!layer.sources.empty()) continue;

        const int slots = (layer.branch == -1) ? UI_CONTEXT : UI_CONTEXT/2;
        const int offset = (layer.branch == 1) ? UI_CONTEXT/2 : 0;
        std::vector<double> d;
        d.reserve((size_t)UI_CONTEXT * UI_EMBED);
        int used = 0;
        for(int j = std::max(0, total - UI_CONTEXT + offset); j < total && used < slots; j++, used++) {
            const auto& v = emb.get(toks[j]);
            d.insert(d.end(), v.begin(), v.end());
        }
        d.insert(d.end(), (size_t)(slots - used) * (size_t)UI_EMBED, 0.0);

        if(layer.size > 0) d.resize((size_t)layer.size, 0.0);
        res[(int)i] = std::move(d);
    }
    return res;
}
/// Background training loop; main() runs it on a detached std::thread.
///
/// For every epoch and sample, the AI reply is fed one token at a time. The
/// network is trained to predict aiToks[i] (a one-hot target over the output
/// layer) from the user tokens plus the reply tokens seen so far. The context is
/// limited to UI_CONTEXT tokens.
///
/// Loss, progress, epoch and speed are published to `ui` under `ui.mtx`.
/// Setting `ui.stop` ends training early. On exit, `ui.stop` is cleared and then
/// `ui.training`, so `training == false` means the thread is done with
/// nn/tk/eb.
///
/// The dataset, epoch count and learning rate are copied at the start, because
/// the UI thread can re-parse the dataset or edit these fields while training.
/// Before, that re-parse replaced the vector this loop was iterating.
void TrainTask(NeuralNetwork* nn, Tokenizer* tk, Embedder* eb) {
    const std::vector<TrainingSample> samples = ui.parsedSamples;
    const int epochs = ui.epochs;
    const float lr = ui.lr;
    const int os = ui.layers.empty() ? 0 : ui.layers.back().size;
    if(samples.empty() || os <= 0 || !nn || !tk || !eb) {
        ui.training = false;
        return;
    }

    double totalLoss = 0;
    int sampleCount = 0;
    const auto trainingStart = std::chrono::steady_clock::now();
    {
        std::lock_guard<std::mutex> lk(ui.mtx);
        ui.trainingStartTime = trainingStart;
        ui.totalTokensTrained = 0;
    }
    const float totalSteps = (float)epochs * (float)samples.size();

    for(int e = 0; e < epochs && !ui.stop; e++) {
        const auto epochStart = std::chrono::steady_clock::now();
        double epochLoss = 0;
        int epochCount = 0;
        for(size_t s = 0; s < samples.size() && !ui.stop; s++) {
            auto userToks = tk->textToTokens(samples[s].userText);
            auto aiToks = tk->textToTokens(samples[s].aiText);
            if(userToks.empty() || aiToks.empty()) continue;
            std::vector<int> context = userToks;
            for(size_t i = 0; i < aiToks.size() && !ui.stop; i++) {
                const auto tokenStart = std::chrono::steady_clock::now();
                std::vector<double> target((size_t)os, 0.0);
                if(aiToks[i] >= 0 && aiToks[i] < os) target[aiToks[i]] = 1.0;
                const double loss = nn->train(PrepInput(context, *eb), target, lr);
                totalLoss += loss;
                sampleCount++;
                epochLoss += loss;
                epochCount++;
                context.push_back(aiToks[i]);
                if((int)context.size() > UI_CONTEXT) context.erase(context.begin());

                // Instantaneous speed: the inverse of this token's training time.
                const double tokenTime = std::chrono::duration<double>(std::chrono::steady_clock::now() - tokenStart).count();
                {
                    std::lock_guard<std::mutex> lk(ui.mtx);
                    ui.totalTokensTrained++;
                    if(tokenTime > 0) ui.tokensPerSecond = 1.0 / tokenTime;
                    ui.lastStatus.loss = loss;
                    const float currentStep = (float)e * (float)samples.size() + (float)s;
                    ui.lastStatus.progress = (currentStep / totalSteps) * 100.0f;
                    ui.lastStatus.epoch = e + 1;
                    ui.lastStatus.speed = ui.tokensPerSecond;
                }
            }
        }

        const double epochTime = std::chrono::duration<double>(std::chrono::steady_clock::now() - epochStart).count();
        double speed = 0.0;
        {
            std::lock_guard<std::mutex> lk(ui.mtx);
            ui.doneEpochs++;
            speed = ui.tokensPerSecond;
        }
        // Log every epoch; the loss is this epoch's mean (0 if no tokens were trained).
        std::cout << "Epoch " << (e+1) << "/" << epochs
                  << " | Loss: " << (epochCount > 0 ? epochLoss / epochCount : 0.0)
                  << " | Time: " << epochTime << "s"
                  << " | Tokens/sec: " << speed
                  << "\n";
    }

    {
        std::lock_guard<std::mutex> lk(ui.mtx);
        if(sampleCount > 0) {
            ui.lastStatus.totalLoss = totalLoss / sampleCount;
            // Report the overall average speed once training ends.
            const double totalTime = std::chrono::duration<double>(std::chrono::steady_clock::now() - trainingStart).count();
            if(totalTime > 0) ui.tokensPerSecond = ui.totalTokensTrained / totalTime;
        }
        if(!ui.stop) ui.lastStatus.progress = 100.0f;  // the last step's formula stops short of 100
    }
    ui.stop = false;
    ui.training = false;
}
// ============================================================================
// MAIN
// ============================================================================
int main() {
if(!glfwInit()) return 1;
GLFWwindow* win = glfwCreateWindow(1600,1000,"Xenith Studio",nullptr,nullptr);
glfwMakeContextCurrent(win); glfwSwapInterval(1);
IMGUI_CHECKVERSION(); ImGui::CreateContext();
ImGuiIO& io = ImGui::GetIO();
io.Fonts->AddFontFromFileTTF("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 16.0f, nullptr, io.Fonts->GetGlyphRangesCyrillic());
ImGui_ImplGlfw_InitForOpenGL(win, true); ImGui_ImplOpenGL3_Init("#version 130");
ui.layers = {
LayerStructure_t(UI_CONTEXT*UI_EMBED,{}),
LayerStructure_t(512,{0}),
LayerStructure_t(UI_VOCAB,{1})
};
ui.layers[0].branch = -1;
NodeEditor::SyncFromConfigs(ui.graph, ui.layers);
ui.chats.push_back(ChatSession(0, "Chat 1"));
Tokenizer tok;
Embedder emb(UI_VOCAB, UI_EMBED);
NeuralNetwork* nn = new NeuralNetwork(ui.layers.data(), ui.layers.size(), true);
// Try load architecture
if(NodeEditor::LoadArchitectureFromFile(ui.layers, ui.archFilePath)) {
NodeEditor::SyncFromConfigs(ui.graph, ui.layers);
delete nn;
nn = new NeuralNetwork(ui.layers.data(), ui.layers.size(), true);
std::cout << "Loaded architecture from " << ui.archFilePath << "\n";
}
while(!glfwWindowShouldClose(win)) {
glfwPollEvents();
ImGui_ImplOpenGL3_NewFrame(); ImGui_ImplGlfw_NewFrame(); ImGui::NewFrame();
ImGui::SetNextWindowPos(ImVec2(0,0)); ImGui::SetNextWindowSize(io.DisplaySize);
ImGui::Begin("Main", nullptr, ImGuiWindowFlags_NoDecoration|ImGuiWindowFlags_MenuBar|ImGuiWindowFlags_NoMove|ImGuiWindowFlags_NoResize);
if(ImGui::BeginMenuBar()) {
if(ImGui::BeginMenu("File")) {
if(ImGui::MenuItem("Apply & Save Arch")) {
NodeEditor::SyncToConfigs(ui.graph, ui.layers);
NodeEditor::SaveArchitectureToFile(ui.layers, ui.archFilePath);
delete nn;
nn = new NeuralNetwork(ui.layers.data(), ui.layers.size(), true);
ui.architectureText = NodeEditor::GenerateArchitectureText(ui.graph, ui.layers);
}
if(ImGui::MenuItem("Load Arch")) {
if(NodeEditor::LoadArchitectureFromFile(ui.layers, ui.archFilePath)) {
NodeEditor::SyncFromConfigs(ui.graph, ui.layers);
delete nn;
nn = new NeuralNetwork(ui.layers.data(), ui.layers.size(), true);
ui.architectureText = NodeEditor::GenerateArchitectureText(ui.graph, ui.layers);
}
}
ImGui::EndMenu();
}
ImGui::EndMenuBar();
}
{
std::lock_guard<std::mutex> lk(ui.mtx);
ImGui::Text("Epoch: %d/%d | Loss: %.5f | %.1f%%",
ui.lastStatus.epoch, ui.epochs, ui.lastStatus.loss, ui.lastStatus.progress);
ImGui::SameLine();
ImGui::TextColored(ui.training?ImVec4(0,1,0,1):ImVec4(1,1,0,1),
ui.training?"[TRAINING]":"[IDLE]");
ImGui::ProgressBar(ui.lastStatus.progress/100.0f, ImVec2(-1,18));
}
ImGui::Separator();
ImGui::Columns(3, "MainCols", true);
ImGui::SetColumnWidth(0, io.DisplaySize.x*0.35f);
ImGui::SetColumnWidth(1, io.DisplaySize.x*0.35f);
// === COLUMN 1: NODE EDITOR ===
ImGui::BeginChild("Graph", ImVec2(0,0), true);
ImGui::TextColored(ImVec4(0,1,1,1), "NODE GRAPH EDITOR");
ImGui::Separator();
NodeEditor::DrawGraph(ui.graph, ImGui::GetContentRegionAvail());
ImGui::EndChild();
// === COLUMN 2: ARCHITECTURE & SETTINGS ===
ImGui::NextColumn();
ImGui::BeginChild("Info", ImVec2(0,0), true);
if(ImGui::BeginTabBar("InfoTabs")) {
if(ImGui::BeginTabItem("Architecture")) {
if(ImGui::Button("Refresh")) {
ui.architectureText = NodeEditor::GenerateArchitectureText(ui.graph, ui.layers);
}
ImGui::SameLine();
ImGui::InputText("Arch File", ui.archFilePath, 256);
ImGui::Separator();
ImGui::BeginChild("ArchText", ImVec2(0,0), true);
ImGui::TextWrapped("%s", ui.architectureText.c_str());
ImGui::EndChild();
ImGui::EndTabItem();
}
if(ImGui::BeginTabItem("Training")) {
ImGui::TextColored(ImVec4(0,1,1,1), "Performance Metrics:");
ImGui::Separator();
{
std::lock_guard<std::mutex> lk(ui.mtx);
ImGui::Text("Training Speed: %.2f tokens/sec", ui.tokensPerSecond);
ImGui::Text("Generation Speed: %.2f tokens/sec", ui.generationTokensPerSecond);
ImGui::Text("Total Tokens Trained: %d", ui.totalTokensTrained);
}
ImGui::Separator();
ImGui::SliderFloat("Learning Rate", &ui.lr, 0.0001f, 0.1f, "%.5f");
ImGui::InputInt("Epochs", &ui.epochs); if(ui.epochs<1) ui.epochs=1;
ImGui::InputInt("Max Response Tokens", &MAX_RESPONSE_TOKENS); if(MAX_RESPONSE_TOKENS<1) MAX_RESPONSE_TOKENS=1;
ImGui::InputInt("Context Size", &UI_CONTEXT); if(UI_CONTEXT<1) UI_CONTEXT=1;
ImGui::InputInt("Embedding Dim", &UI_EMBED); if(UI_EMBED<1) UI_EMBED=1;
ImGui::InputInt("Vocab Size", &UI_VOCAB); if(UI_VOCAB<1) UI_VOCAB=1;
ImGui::Separator();
ImGui::TextColored(ImVec4(1,0.5f,0,1), "⚠️ Recommendations:");
if(ui.parsedSamples.size() < 10) {
ImGui::TextColored(ImVec4(1,0,0,1), "• Too few samples! Add at least 10-20 examples");
}
if(ui.epochs > 50 && ui.parsedSamples.size() < 5) {
ImGui::TextColored(ImVec4(1,0,0,1), "• Too many epochs for small dataset! Reduce to 10-20");
}
if(ui.lr > 0.01f) {
ImGui::TextColored(ImVec4(1,0.8f,0,1), "• High learning rate! Try 0.001-0.01");
}
if(ImGui::Button("Load Dataset from File")) {
std::string tempBuffer; // 1. Создаем переменную для загрузки
if(LoadDatasetFromFile(tempBuffer, ui.datasetFilePath)) {
// 2. Копируем данные из переменной в буфер ImGui
// Используем strncpy, чтобы не переполнить массив ui.dsBuf
strncpy(ui.dsBuf, tempBuffer.c_str(), sizeof(ui.dsBuf) - 1);
ui.dsBuf[sizeof(ui.dsBuf) - 1] = '\0'; // Гарантируем завершение строки нулем
// 3. Парсим обновленный буфер
ui.parsedSamples = ParseDataset(ui.dsBuf, tok);
ui.datasetLoaded = true;
std::cout << "Loaded " << ui.parsedSamples.size() << " samples\n";
} else {
std::cerr << "Failed to load dataset\n";
}
}
ImGui::SameLine();
if(ui.datasetLoaded) ImGui::TextColored(ImVec4(0,1,0,1), "✓ %lu samples", ui.parsedSamples.size());
else ImGui::Text("Not loaded");
ImGui::Separator();
if(!ui.training) {
if(ui.parsedSamples.empty()) {
ImGui::PushStyleVar(ImGuiStyleVar_Alpha, 0.5f);
ImGui::Button("▶ START TRAINING (Load dataset first)", ImVec2(-1,45));
ImGui::PopStyleVar();
} else {
if(ImGui::Button("▶ START TRAINING", ImVec2(-1,45))) {
ui.training=true;
ui.stop=false;
ui.doneEpochs=0;
std::thread(TrainTask, nn, &tok, &emb).detach();
}
}
} else {
ImGui::PushStyleColor(ImGuiCol_Button, ImVec4(0.8f,0.2f,0.2f,1));
if(ImGui::Button("⏹ STOP", ImVec2(-1,45))) ui.stop=true;
ImGui::PopStyleColor();
}
ImGui::EndTabItem();
}
if(ImGui::BeginTabItem("Dataset")) {
if(ImGui::Button("Parse from Text")) {
ui.parsedSamples = ParseDataset(std::string(ui.dsBuf), tok);
ui.datasetLoaded = !ui.parsedSamples.empty();
}
ImGui::SameLine();
ImGui::Text("(%lu samples)", ui.parsedSamples.size());
ImGui::InputTextMultiline("##DS", ui.dsBuf, sizeof(ui.dsBuf), ImVec2(-1,-1),
ImGuiInputTextFlags_AllowTabInput);
ImGui::TextWrapped("Format: [USER]text[AI]response[EOS]");
ImGui::EndTabItem();
}
ImGui::EndTabBar();
}
ImGui::EndChild();
// === COLUMN 3: CHATS ===
ImGui::NextColumn();
ImGui::BeginChild("Chats", ImVec2(0,0), true);
static char newChatName[64] = "";
if(ImGui::Button("+ New Chat")) {
ui.chats.push_back(ChatSession(ui.chats.size(), std::string("Chat ") + std::to_string(ui.chats.size()+1)));
}
ImGui::SameLine();
ImGui::InputText("##NewChatName", newChatName, 64);
for(size_t i=0;i<ui.chats.size();i++) {
if(ImGui::Selectable(ui.chats[i].name.c_str(), (int)i==ui.activeChat)) {
ui.activeChat = i;
}
}
ImGui::Separator();
if(!ui.chats.empty() && ui.activeChat < ui.chats.size()) {
auto& chat = ui.chats[ui.activeChat];
ImGui::Text("%s", chat.name.c_str());
ImGui::Separator();
ImGui::BeginChild("Messages", ImVec2(0, -80), true);
for(auto& msg : chat.messages) {
if(msg.role == "USER") {
ImGui::PushStyleColor(ImGuiCol_Text, ImVec4(0.5f,0.8f,1.0f,1));
ImGui::Text("[You]: %s", msg.content.c_str());
ImGui::PopStyleColor();
} else {
ImGui::PushStyleColor(ImGuiCol_Text, ImVec4(0.8f,1.0f,0.5f,1));
ImGui::TextWrapped("[AI]: %s", msg.content.c_str());
ImGui::PopStyleColor();
}
}
if(ui.scrollChat) { ImGui::SetScrollHereY(1.0f); ui.scrollChat = false; }
ImGui::EndChild();
if(ImGui::InputText("##Input", ui.inputBuf, 512, ImGuiInputTextFlags_EnterReturnsTrue)) {
std::string input = ui.inputBuf;
chat.messages.push_back({"USER", input});
auto ctx = tok.textToTokens(input);
std::string ans;
int generatedTokens = 0;
auto genStartTime = std::chrono::steady_clock::now();
for(int g=0; g<MAX_RESPONSE_TOKENS; g++) {
auto tokenStart = std::chrono::steady_clock::now();
auto out = nn->feedForward(PrepInput(ctx, emb));
int id = std::max_element(out.begin(),out.end())-out.begin();
if(id<=0||id>=UI_VOCAB) break;
std::string w = tok.getWord(id);
if(w.empty()) break;
ans += w + " ";
ctx.push_back(id);
if((int)ctx.size()>UI_CONTEXT) ctx.erase(ctx.begin());
generatedTokens++;
auto tokenEnd = std::chrono::steady_clock::now();
double tokenTime = std::chrono::duration<double>(tokenEnd - tokenStart).count();
if(tokenTime > 0) {
ui.generationTokensPerSecond = 1.0 / tokenTime;
}
}
auto genEndTime = std::chrono::steady_clock::now();
double genTime = std::chrono::duration<double>(genEndTime - genStartTime).count();
// Добавляем информацию о скорости в лог (опционально)
std::cout << "Generated " << generatedTokens << " tokens in "
<< genTime << "s ("
<< (generatedTokens/genTime) << " tok/s)\n";
chat.messages.push_back({"AI", ans});
ui.inputBuf[0]=0;
ui.scrollChat=true;
}
}
ImGui::EndChild();
ImGui::Columns(1);
ImGui::End();
ImGui::Render();
glViewport(0,0,(int)io.DisplaySize.x,(int)io.DisplaySize.y);
glClearColor(0.08f,0.09f,0.12f,1); glClear(GL_COLOR_BUFFER_BIT);
ImGui_ImplOpenGL3_RenderDrawData(ImGui::GetDrawData());
glfwSwapBuffers(win);
}
delete nn;
ImGui_ImplOpenGL3_Shutdown(); ImGui_ImplGlfw_Shutdown(); ImGui::DestroyContext();
glfwTerminate();
return 0;
}