984 lines
41 KiB
C++
984 lines
41 KiB
C++
// ============================================================================
|
|
// Xenith Studio - Node Editor v3.1 (Исправлено обучение + Загрузка из файла + Фикс портов)
|
|
// ============================================================================
|
|
|
|
#define IMGUI_DEFINE_MATH_OPERATORS
|
|
#include "imgui.h"
|
|
#include "imgui_internal.h"
|
|
#include "imgui_impl_glfw.h"
|
|
#include "imgui_impl_opengl3.h"
|
|
#include <GLFW/glfw3.h>
|
|
|
|
#include <iostream>
|
|
#include <fstream>
|
|
#include <vector>
|
|
#include <string>
|
|
#include <map>
|
|
#include <algorithm>
|
|
#include <thread>
|
|
#include <atomic>
|
|
#include <mutex>
|
|
#include <cmath>
|
|
#include <sstream>
|
|
#include <iomanip>
|
|
|
|
#include "Xenith/core.hpp"
|
|
#include "Xenith/token/token.hpp"
|
|
|
|
// ============================================================================
|
|
// NODE EDITOR ENGINE
|
|
// ============================================================================
|
|
|
|
namespace NodeEditor {
|
|
|
|
/// Direction of a node port.
enum class PortType {
    Input,   // receives a connection
    Output,  // emits a connection
};

/// Role of a node in the graph. Splitter and Merger are routing nodes whose
/// port count follows their branch count.
enum class NodeType {
    Input,
    Hidden,
    Output,
    Splitter,
    Merger,
};

/// A named connection point on a node.
struct Port {
    std::string name;  // label drawn next to the port
    PortType type;     // which side of the node the port lives on
    int index;         // slot number (callers in this file leave it at 0)

    Port(const std::string& label, PortType kind, int slot = 0)
        : name{label}, type{kind}, index{slot} {}
};
|
|
|
|
struct Node {
|
|
int id;
|
|
std::string title;
|
|
NodeType type;
|
|
ImVec2 pos;
|
|
ImVec2 size;
|
|
bool selected;
|
|
bool dragging;
|
|
ImVec2 dragOffset;
|
|
|
|
int layerSize;
|
|
int branchCount;
|
|
int activeBranch;
|
|
std::vector<Port> inputs;
|
|
std::vector<Port> outputs;
|
|
|
|
Node(int id_, const std::string& t_, NodeType type_)
|
|
: id(id_), title(t_), type(type_), pos(0,0), size(180,100),
|
|
selected(false), dragging(false), layerSize(128), branchCount(2), activeBranch(-1) {
|
|
UpdatePorts();
|
|
}
|
|
|
|
void UpdatePorts() {
|
|
inputs.clear(); outputs.clear();
|
|
switch(type) {
|
|
case NodeType::Input:
|
|
// ИСПРАВЛЕНО: 0 входов, 1 выход
|
|
outputs.push_back(Port("Out", PortType::Output));
|
|
break;
|
|
case NodeType::Hidden:
|
|
// ИСПРАВЛЕНО: 1 вход, 1 выход
|
|
inputs.push_back(Port("In", PortType::Input));
|
|
outputs.push_back(Port("Out", PortType::Output));
|
|
break;
|
|
case NodeType::Output:
|
|
// ИСПРАВЛЕНО: 1 вход, 0 выходов
|
|
inputs.push_back(Port("In", PortType::Input));
|
|
break;
|
|
case NodeType::Splitter:
|
|
inputs.push_back(Port("In", PortType::Input));
|
|
for(int i=0;i<branchCount;i++) outputs.push_back(Port("Br "+std::to_string(i), PortType::Output));
|
|
break;
|
|
case NodeType::Merger:
|
|
for(int i=0;i<branchCount;i++) inputs.push_back(Port("Br "+std::to_string(i), PortType::Input));
|
|
outputs.push_back(Port("Out", PortType::Output));
|
|
break;
|
|
}
|
|
}
|
|
|
|
ImVec2 GetPortScreenPos(int portIdx, bool isInput, const ImVec2& canvasPos, const ImVec2& pan) const {
|
|
float y = pos.y + 30 + portIdx * 22;
|
|
return canvasPos + (isInput ? ImVec2(pos.x - 6, y) : ImVec2(pos.x + size.x + 6, y)) + pan;
|
|
}
|
|
};
|
|
|
|
/// Directed link from an output port of one node to an input port of another.
struct Connection {
    int fromNode;  // id of the source node
    int fromPort;  // output-port index on the source node
    int toNode;    // id of the destination node
    int toPort;    // input-port index on the destination node

    Connection(int srcNode, int srcPort, int dstNode, int dstPort)
        : fromNode{srcNode}, fromPort{srcPort}, toNode{dstNode}, toPort{dstPort} {}
};
|
|
|
|
struct GraphState {
|
|
std::vector<Node> nodes;
|
|
std::vector<Connection> connections;
|
|
int nextId = 0;
|
|
int selectedNode = -1;
|
|
int hoveredPortNode = -1, hoveredPortIdx = -1;
|
|
PortType hoveredPortType = PortType::Input;
|
|
|
|
bool creatingConn = false;
|
|
int connStartNode = -1, connStartPort = -1;
|
|
PortType connStartType = PortType::Output;
|
|
ImVec2 connMousePos;
|
|
|
|
ImVec2 pan = ImVec2(100, 80);
|
|
bool panning = false;
|
|
ImVec2 panStart;
|
|
};
|
|
|
|
/// Border colour for a node: yellow when selected, otherwise one colour per
/// node type, with grey as a fallback.
ImU32 GetNodeColor(NodeType type, bool selected) {
    ImU32 color = IM_COL32(120,120,120,255); // fallback for unknown types
    if(selected) {
        color = IM_COL32(255,255,100,255);
    } else if(type == NodeType::Input) {
        color = IM_COL32(80,200,120,255);
    } else if(type == NodeType::Hidden) {
        color = IM_COL32(80,140,220,255);
    } else if(type == NodeType::Output) {
        color = IM_COL32(220,80,80,255);
    } else if(type == NodeType::Splitter) {
        color = IM_COL32(220,180,50,255);
    } else if(type == NodeType::Merger) {
        color = IM_COL32(180,80,220,255);
    }
    return color;
}
|
|
|
|
/// Renders one node: body, border, title, neuron count, optional branch count,
/// and every port with its label. A port under the mouse is highlighted.
void DrawNode(ImDrawList* dl, const Node& n, const GraphState& g, const ImVec2& canvasPos, const ImVec2& pan) {
    const ImVec2 topLeft = canvasPos + n.pos + pan;
    const ImVec2 bottomRight = topLeft + n.size;

    // Body and type-coloured border
    dl->AddRectFilled(topLeft, bottomRight, IM_COL32(35,38,48,255), 6);
    dl->AddRect(topLeft, bottomRight, GetNodeColor(n.type, n.selected), 6, 0, 2.0f);

    // Title and neuron count
    dl->AddText(topLeft + ImVec2(10,4), IM_COL32(255,255,255,255), n.title.c_str());
    const std::string sizeLabel = std::to_string(n.layerSize);
    dl->AddText(topLeft + ImVec2(10,20), IM_COL32(160,160,160,200), sizeLabel.c_str());

    const bool isRouting = (n.type == NodeType::Splitter) || (n.type == NodeType::Merger);
    if(isRouting) {
        const std::string branchLabel = " x" + std::to_string(n.branchCount);
        dl->AddText(topLeft + ImVec2(10, n.size.y-18), IM_COL32(255,220,100,255), branchLabel.c_str());
    }

    // Port circle plus label; labels sit on the inner side of the node.
    auto drawPort = [&](const Port& port, int idx, bool isInput) {
        const ImVec2 center = n.GetPortScreenPos(idx, isInput, canvasPos, pan);
        const PortType side = isInput ? PortType::Input : PortType::Output;
        const bool hovered = g.hoveredPortNode == n.id && g.hoveredPortIdx == idx && g.hoveredPortType == side;
        dl->AddCircleFilled(center, 5, hovered ? IM_COL32(255,255,100,255) : IM_COL32(180,180,180,255));
        dl->AddCircle(center, 5, IM_COL32(50,50,50,255), 12, 1.5f);

        const char* label = port.name.c_str();
        const ImVec2 labelPos = isInput
            ? center + ImVec2(10,-5)
            : center - ImVec2(10 + ImGui::CalcTextSize(label).x, 5);
        dl->AddText(labelPos, IM_COL32(200,200,200,255), label);
    };

    for(size_t k = 0; k < n.inputs.size(); ++k)  drawPort(n.inputs[k],  (int)k, true);
    for(size_t k = 0; k < n.outputs.size(); ++k) drawPort(n.outputs[k], (int)k, false);
}
|
|
|
|
/// Draws every established connection as a cubic Bezier, plus the in-progress
/// connection (if any) from its start port to the mouse.
/// Connections whose nodes or ports no longer exist are skipped. This happens,
/// for example, after a splitter's branch count is lowered.
void DrawConnections(ImDrawList* dl, const GraphState& g, const ImVec2& canvasPos, const ImVec2& pan) {
    // Look up by id without copying the graph (a `[g]` capture copies every node each frame).
    auto findNode = [&g](int id) -> const Node* {
        for(const auto& n : g.nodes) if(n.id == id) return &n;
        return nullptr;
    };

    for(const auto& c : g.connections) {
        const Node* fn = findNode(c.fromNode);
        const Node* tn = findNode(c.toNode);
        if(!fn || !tn) continue;
        if(c.fromPort < 0 || c.fromPort >= (int)fn->outputs.size()) continue;
        if(c.toPort < 0 || c.toPort >= (int)tn->inputs.size()) continue;

        const ImVec2 start = fn->GetPortScreenPos(c.fromPort, false, canvasPos, pan);
        const ImVec2 end = tn->GetPortScreenPos(c.toPort, true, canvasPos, pan);
        const ImVec2 cp1 = start + ImVec2(60,0);
        const ImVec2 cp2 = end - ImVec2(60,0);
        dl->AddBezierCubic(start, cp1, cp2, end, IM_COL32(160,160,160,150), 2.0f, 24);
    }

    if(g.creatingConn) {
        if(const Node* sn = findNode(g.connStartNode)) {
            const bool fromOutput = (g.connStartType == PortType::Output);
            const ImVec2 start = sn->GetPortScreenPos(g.connStartPort, !fromOutput, canvasPos, pan);
            const ImVec2 end = g.connMousePos;
            // Bend the curve away from the side of the node the drag started on:
            // rightwards from output ports, leftwards from input ports.
            const ImVec2 bend(fromOutput ? 60.0f : -60.0f, 0.0f);
            dl->AddBezierCubic(start, start + bend, end - bend, end, IM_COL32(255,255,100,180), 2.5f, 24);
        }
    }
}
|
|
|
|
/// Processes mouse input for the canvas while it is hovered:
///  - MMB drags the view; the wheel scrolls it vertically.
///  - LMB on a port starts a connection; releasing on a compatible port of
///    another node creates it (duplicates are ignored). RMB cancels.
///  - LMB on a node selects it (only it) and starts dragging it.
///  - LMB on empty space clears the selection.
/// @param canvasSize unused; kept for interface compatibility.
void HandleInput(GraphState& g, const ImVec2& canvasPos, const ImVec2& canvasSize) {
    (void)canvasSize;
    ImGuiIO& io = ImGui::GetIO();
    const ImVec2 mouseScreen = io.MousePos;

    // --- Panning ---
    if(ImGui::IsMouseDown(ImGuiMouseButton_Middle)) {
        if(!g.panning) { g.panning = true; g.panStart = mouseScreen - g.pan; }
        g.pan = mouseScreen - g.panStart;
    } else {
        g.panning = false;
    }
    if(ImGui::IsWindowHovered() && io.MouseWheel != 0.0f) {
        g.pan.y -= io.MouseWheel * 20.0f;
    }

    // Graph-space mouse position, taken after this frame's pan update so a
    // dragged node stays under the cursor.
    const ImVec2 mouseWorld = mouseScreen - canvasPos - g.pan;

    // --- Port hover (8 px radius; later nodes win) ---
    g.hoveredPortNode = -1; g.hoveredPortIdx = -1;
    for(const auto& n : g.nodes) {
        for(size_t i=0;i<n.inputs.size();i++) {
            const ImVec2 p = n.GetPortScreenPos((int)i, true, canvasPos, g.pan);
            if(ImLengthSqr(mouseScreen-p) < 64) { g.hoveredPortNode=n.id; g.hoveredPortIdx=(int)i; g.hoveredPortType=PortType::Input; }
        }
        for(size_t i=0;i<n.outputs.size();i++) {
            const ImVec2 p = n.GetPortScreenPos((int)i, false, canvasPos, g.pan);
            if(ImLengthSqr(mouseScreen-p) < 64) { g.hoveredPortNode=n.id; g.hoveredPortIdx=(int)i; g.hoveredPortType=PortType::Output; }
        }
    }

    const bool leftClicked  = ImGui::IsMouseClicked(ImGuiMouseButton_Left);
    const bool leftDown     = ImGui::IsMouseDown(ImGuiMouseButton_Left);
    const bool leftReleased = ImGui::IsMouseReleased(ImGuiMouseButton_Left);

    // --- Connection creation ---
    if(leftClicked && g.hoveredPortNode!=-1 && !g.panning) {
        g.creatingConn=true; g.connStartNode=g.hoveredPortNode; g.connStartPort=g.hoveredPortIdx; g.connStartType=g.hoveredPortType;
    }
    if(g.creatingConn) {
        g.connMousePos = mouseScreen;
        if(ImGui::IsMouseClicked(ImGuiMouseButton_Right)) {
            g.creatingConn = false;
        } else if(!leftDown) {
            // The button is up. Connect only if it was released here, this frame,
            // over a port of the opposite kind on a different node. If it was
            // released outside the canvas, the drag is just cancelled.
            const bool validTarget = leftReleased && g.hoveredPortNode!=-1
                && g.hoveredPortNode!=g.connStartNode && g.hoveredPortType!=g.connStartType;
            if(validTarget) {
                const bool startIsOutput = (g.connStartType==PortType::Output);
                const int fromNode = startIsOutput ? g.connStartNode : g.hoveredPortNode;
                const int fromPort = startIsOutput ? g.connStartPort : g.hoveredPortIdx;
                const int toNode   = startIsOutput ? g.hoveredPortNode : g.connStartNode;
                const int toPort   = startIsOutput ? g.hoveredPortIdx : g.connStartPort;
                const bool duplicate = std::any_of(g.connections.begin(), g.connections.end(),
                    [&](const Connection& c) {
                        return c.fromNode==fromNode && c.fromPort==fromPort && c.toNode==toNode && c.toPort==toPort;
                    });
                if(!duplicate) g.connections.emplace_back(fromNode, fromPort, toNode, toPort);
            }
            g.creatingConn = false;
        }
    }

    // --- Selection: only the topmost node under the cursor (last drawn = on top) ---
    if(leftClicked && !g.creatingConn) {
        Node* hit = nullptr;
        for(auto it = g.nodes.rbegin(); it != g.nodes.rend(); ++it) {
            const ImVec2 scr = canvasPos + it->pos + g.pan;
            if(ImRect(scr, scr + it->size).Contains(mouseScreen)) { hit = &*it; break; }
        }
        if(hit) {
            for(auto& n : g.nodes) { n.selected = false; n.dragging = false; }
            hit->selected = true; hit->dragging = true; g.selectedNode = hit->id;
            hit->dragOffset = mouseWorld - hit->pos;
        } else if(g.hoveredPortNode==-1 && !g.panning) {
            // Click on empty canvas clears the selection
            g.selectedNode = -1;
            for(auto& n : g.nodes) n.selected = false;
        }
    }

    // --- Dragging: follow the mouse while LMB is held; stop once it is up,
    // even if the release happened outside the canvas ---
    for(auto& n : g.nodes) {
        if(!n.dragging) continue;
        if(leftDown || leftReleased) n.pos = mouseWorld - n.dragOffset;
        if(!leftDown) n.dragging = false;
    }
}
|
|
|
|
/// Draws the node canvas (grid, connections, nodes), routes mouse input to
/// HandleInput, shows the right-click "add node" menu, and shows a properties
/// panel for the selected node.
void DrawGraph(GraphState& g, const ImVec2& canvasSize) {
    ImDrawList* dl = ImGui::GetWindowDrawList();
    const ImVec2 canvasPos = ImGui::GetCursorScreenPos();
    // InvisibleButton asserts on a zero-sized item, which a squeezed child window can produce.
    const ImVec2 area(std::max(canvasSize.x, 1.0f), std::max(canvasSize.y, 1.0f));

    // --- Background and grid (the grid scrolls with the pan offset) ---
    dl->AddRectFilled(canvasPos, canvasPos+area, IM_COL32(25,27,35,255));
    const float gs = 40.0f;
    const ImVec2 off(fmodf(g.pan.x, gs), fmodf(g.pan.y, gs));
    for(float x=off.x; x<area.x; x+=gs) dl->AddLine(canvasPos+ImVec2(x,0), canvasPos+ImVec2(x,area.y), IM_COL32(40,44,55,120));
    for(float y=off.y; y<area.y; y+=gs) dl->AddLine(canvasPos+ImVec2(0,y), canvasPos+ImVec2(area.x,y), IM_COL32(40,44,55,120));

    // --- Graph contents ---
    DrawConnections(dl, g, canvasPos, g.pan);
    for(const auto& n : g.nodes) DrawNode(dl, n, g, canvasPos, g.pan);

    // --- Input ---
    ImGui::InvisibleButton("##Canvas", area);
    if(ImGui::IsItemHovered()) {
        HandleInput(g, canvasPos, area);
    } else {
        // HandleInput is not called off-canvas. Clear the hover highlight, and
        // end interactions whose button is already up so they do not get stuck.
        g.hoveredPortNode = -1; g.hoveredPortIdx = -1;
        if(!ImGui::IsMouseDown(ImGuiMouseButton_Left)) {
            g.creatingConn = false;
            for(auto& n : g.nodes) n.dragging = false;
        }
        if(!ImGui::IsMouseDown(ImGuiMouseButton_Middle)) g.panning = false;
    }

    // --- Right-click menu: add nodes ---
    if(ImGui::BeginPopupContextItem("##Ctx")) {
        // Place new nodes where the right-click happened, not where the menu item is clicked.
        const ImVec2 wPos = ImGui::GetMousePosOnOpeningCurrentPopup() - canvasPos - g.pan;
        auto spawn = [&](const char* title, NodeType type) -> Node& {
            Node& n = g.nodes.emplace_back(g.nextId++, title, type);
            n.pos = wPos;
            return n;
        };
        if(ImGui::MenuItem("🟢 Input")) spawn("Input", NodeType::Input);
        if(ImGui::MenuItem(" Hidden")) spawn("Hidden", NodeType::Hidden);
        if(ImGui::MenuItem(" Output")) spawn("Output", NodeType::Output);
        ImGui::Separator();
        if(ImGui::MenuItem(" Splitter (x2)")) { Node& n = spawn("Splitter", NodeType::Splitter); n.branchCount=2; n.UpdatePorts(); }
        if(ImGui::MenuItem(" Merger (x2)")) { Node& n = spawn("Merger", NodeType::Merger); n.branchCount=2; n.UpdatePorts(); }
        ImGui::Separator(); ImGui::Text("LMB: Drag/Connect | RMB: Cancel | MMB: Pan");
        ImGui::EndPopup();
    }

    // --- Properties panel for the selected node ---
    if(g.selectedNode != -1) {
        const int selectedId = g.selectedNode;
        auto it = std::find_if(g.nodes.begin(), g.nodes.end(), [selectedId](const Node& n){ return n.id==selectedId; });
        if(it != g.nodes.end()) {
            Node& sel = *it;
            bool deleteRequested = false;
            ImGui::SetNextWindowPos(canvasPos + ImVec2(10,10));
            ImGui::SetNextWindowSize(ImVec2(240,220));
            if(ImGui::Begin("##Props", nullptr, ImGuiWindowFlags_NoTitleBar|ImGuiWindowFlags_AlwaysAutoResize|ImGuiWindowFlags_NoMove)) {
                ImGui::TextColored(ImVec4(1,1,0.5f,1), "%s # %d", sel.title.c_str(), sel.id); ImGui::Separator();

                if(sel.type != NodeType::Output) { ImGui::InputInt("Neurons", &sel.layerSize); if(sel.layerSize<1) sel.layerSize=1; }

                if(sel.type == NodeType::Splitter || sel.type == NodeType::Merger) {
                    ImGui::SliderInt("Branches", &sel.branchCount, 2, 8);
                    // Ctrl+click text entry can bypass the slider range
                    sel.branchCount = std::clamp(sel.branchCount, 2, 8);
                    if(ImGui::IsItemDeactivatedAfterEdit()) {
                        sel.UpdatePorts();
                        // Drop links to branch ports that no longer exist
                        const int nodeId = sel.id;
                        const int inCount = (int)sel.inputs.size();
                        const int outCount = (int)sel.outputs.size();
                        g.connections.erase(std::remove_if(g.connections.begin(), g.connections.end(),
                            [=](const Connection& c){
                                return (c.fromNode==nodeId && c.fromPort>=outCount) || (c.toNode==nodeId && c.toPort>=inCount);
                            }), g.connections.end());
                    }
                } else if(sel.type == NodeType::Input) {
                    ImGui::Text("Branch:");
                    ImGui::RadioButton("Combined", &sel.activeBranch, -1);
                    ImGui::RadioButton("A (0)", &sel.activeBranch, 0);
                    ImGui::RadioButton("B (1)", &sel.activeBranch, 1);
                }

                ImGui::Spacing();
                if(ImGui::Button("🗑 Delete", ImVec2(-1,28))) deleteRequested = true;
            }
            // End() must be called whatever Begin() returned; otherwise the ImGui window stack is unbalanced.
            ImGui::End();

            // Deferred until after End() so `sel` is never used after erase.
            if(deleteRequested) {
                g.connections.erase(std::remove_if(g.connections.begin(), g.connections.end(),
                    [id=selectedId](const Connection& c){ return c.fromNode==id||c.toNode==id; }), g.connections.end());
                g.nodes.erase(it);
                g.selectedNode=-1;
            }
        }
    }
}
|
|
|
|
/// Rebuilds the graph from a layer list: one node per layer (node id == layer
/// index), laid out left to right, plus one connection per entry in each
/// layer's `sources`. A layer with no sources becomes an Input node; otherwise
/// the last layer becomes Output and the rest become Hidden.
void SyncFromConfigs(GraphState& g, const std::vector<LayerStructure_t>& cfgs) {
    g.nodes.clear();
    g.connections.clear();
    g.nextId = 0;

    const size_t layerCount = cfgs.size();

    for(size_t idx = 0; idx < layerCount; ++idx) {
        const auto& layer = cfgs[idx];

        NodeType kind = NodeType::Hidden;
        if(layer.sources.empty()) kind = NodeType::Input;
        else if(idx == layerCount - 1) kind = NodeType::Output;

        const char* label = "Hidden";
        if(kind == NodeType::Input) label = "Input";
        else if(kind == NodeType::Output) label = "Output";

        Node node(g.nextId++, label, kind);
        // Staggered layout so neighbouring layers do not overlap.
        node.pos = ImVec2(120 + idx*260, 150 + (idx%3)*120);
        node.layerSize = layer.size;
        node.activeBranch = layer.branch;
        // The constructor already builds the ports; rebuilt here to be safe.
        node.UpdatePorts();
        g.nodes.push_back(node);
    }

    // Every link goes from output port 0 of the source to input port 0 of the target.
    for(size_t dst = 0; dst < layerCount; ++dst) {
        for(const auto& src : cfgs[dst].sources)
            g.connections.emplace_back(src, 0, (int)dst, 0);
    }
}
|
|
|
|
/// Exports the graph as a layer list. Layer nodes are ordered left to right,
/// except that Output nodes always come last: the rest of the program treats
/// `cfgs.back()` as the output layer. Each incoming connection becomes a
/// source index.
/// NOTE(review): Splitter/Merger nodes are not exported, so links routed
/// through them are dropped.
void SyncToConfigs(GraphState& g, std::vector<LayerStructure_t>& cfgs) {
    cfgs.clear();

    std::vector<const Node*> sorted;
    sorted.reserve(g.nodes.size());
    for(const auto& n : g.nodes)
        if(n.type!=NodeType::Splitter && n.type!=NodeType::Merger) sorted.push_back(&n);

    // Stable sort keeps a deterministic order for nodes sharing the same x.
    std::stable_sort(sorted.begin(), sorted.end(), [](const Node* a, const Node* b){
        const bool aOut = (a->type == NodeType::Output);
        const bool bOut = (b->type == NodeType::Output);
        if(aOut != bOut) return bOut;   // non-output nodes before output nodes
        return a->pos.x < b->pos.x;
    });

    std::map<int,int> idToIdx;
    for(size_t i=0;i<sorted.size();i++) idToIdx[sorted[i]->id] = (int)i;

    for(const Node* n : sorted) {
        LayerStructure_t l; l.size=n->layerSize; l.branch=n->activeBranch;
        for(const auto& c : g.connections) {
            if(c.toNode != n->id) continue;
            const auto src = idToIdx.find(c.fromNode);
            if(src == idToIdx.end()) continue;   // source is a routing node or missing
            l.sources.push_back(src->second);
            l.sourceBranches.push_back(n->activeBranch);
        }
        cfgs.push_back(l);
    }
}
|
|
|
|
std::string GenerateArchitectureText(const GraphState& g, const std::vector<LayerStructure_t>& configs) {
|
|
std::stringstream ss;
|
|
ss << "=== NEURAL NETWORK ARCHITECTURE ===\n\n";
|
|
|
|
long long totalParams = 0;
|
|
for(size_t i=0;i<configs.size();i++) {
|
|
if(configs[i].sources.empty()) continue;
|
|
int inputSize = 0;
|
|
for(int src : configs[i].sources) inputSize += configs[src].size;
|
|
long long layerParams = (long long)inputSize * configs[i].size + configs[i].size;
|
|
totalParams += layerParams;
|
|
}
|
|
|
|
ss << "Total Parameters: " << totalParams << " (" << std::fixed << std::setprecision(2) << (totalParams/1000000.0) << "M)\n\n";
|
|
ss << "Layer Structure:\n";
|
|
ss << "----------------\n";
|
|
|
|
for(size_t i=0;i<configs.size();i++) {
|
|
auto& l = configs[i];
|
|
std::string type = l.sources.empty() ? "INPUT" : (i==configs.size()-1 ? "OUTPUT" : "HIDDEN");
|
|
ss << "Layer " << i << " [" << type << "]: " << l.size << " neurons";
|
|
if(l.branch != -1) ss << " (Branch " << (char)('A'+l.branch) << ")";
|
|
ss << "\n";
|
|
|
|
if(!l.sources.empty()) {
|
|
ss << " ← From: ";
|
|
for(size_t j=0;j<l.sources.size();j++) {
|
|
ss << "Layer " << l.sources[j];
|
|
if(j < l.sources.size()-1) ss << ", ";
|
|
}
|
|
ss << "\n";
|
|
}
|
|
}
|
|
|
|
ss << "\n=== NODE GRAPH ===\n";
|
|
ss << "Total Nodes: " << g.nodes.size() << "\n";
|
|
ss << "Total Connections: " << g.connections.size() << "\n\n";
|
|
|
|
for(const auto& n : g.nodes) {
|
|
ss << "Node #" << n.id << " [" << n.title << "] @ ("
|
|
<< (int)n.pos.x << "," << (int)n.pos.y << ")\n";
|
|
ss << " Size: " << n.layerSize << " neurons\n";
|
|
ss << " Inputs: " << n.inputs.size() << " | Outputs: " << n.outputs.size() << "\n";
|
|
if(n.type == NodeType::Splitter || n.type == NodeType::Merger)
|
|
ss << " Branches: " << n.branchCount << "\n";
|
|
}
|
|
|
|
return ss.str();
|
|
}
|
|
|
|
// === SAVE/LOAD ARCHITECTURE AS TEXT STRUCTURE ===
|
|
bool SaveArchitectureToFile(const std::vector<LayerStructure_t>& configs, const std::string& filename) {
|
|
std::ofstream file(filename);
|
|
if(!file.is_open()) return false;
|
|
|
|
file << "[ARCHITECTURE]\n";
|
|
file << "layers=" << configs.size() << "\n";
|
|
|
|
for(size_t i=0;i<configs.size();i++) {
|
|
auto& l = configs[i];
|
|
file << "[LAYER " << i << "]\n";
|
|
file << "size=" << l.size << "\n";
|
|
file << "branch=" << l.branch << "\n";
|
|
file << "sources=";
|
|
for(size_t j=0;j<l.sources.size();j++) {
|
|
file << l.sources[j];
|
|
if(j < l.sources.size()-1) file << ",";
|
|
}
|
|
file << "\n";
|
|
file << "branches=";
|
|
for(size_t j=0;j<l.sourceBranches.size();j++) {
|
|
file << l.sourceBranches[j];
|
|
if(j < l.sourceBranches.size()-1) file << ",";
|
|
}
|
|
file << "\n\n";
|
|
}
|
|
|
|
file.close();
|
|
return true;
|
|
}
|
|
|
|
bool LoadArchitectureFromFile(std::vector<LayerStructure_t>& configs, const std::string& filename) {
|
|
std::ifstream file(filename);
|
|
if(!file.is_open()) return false;
|
|
|
|
configs.clear();
|
|
std::string line;
|
|
int currentLayer = -1;
|
|
|
|
while(std::getline(file, line)) {
|
|
if(line.find("[LAYER ") == 0) {
|
|
size_t start = line.find(' ') + 1;
|
|
size_t end = line.find(']');
|
|
currentLayer = std::stoi(line.substr(start, end - start));
|
|
configs.push_back(LayerStructure_t());
|
|
} else if(line.find("size=") == 0 && currentLayer >= 0) {
|
|
configs[currentLayer].size = std::stoi(line.substr(5));
|
|
} else if(line.find("branch=") == 0 && currentLayer >= 0) {
|
|
configs[currentLayer].branch = std::stoi(line.substr(7));
|
|
} else if(line.find("sources=") == 0 && currentLayer >= 0) {
|
|
std::string srcs = line.substr(8);
|
|
std::stringstream ss(srcs);
|
|
std::string token;
|
|
while(std::getline(ss, token, ',')) {
|
|
if(!token.empty()) configs[currentLayer].sources.push_back(std::stoi(token));
|
|
}
|
|
} else if(line.find("branches=") == 0 && currentLayer >= 0) {
|
|
std::string brs = line.substr(9);
|
|
std::stringstream ss(brs);
|
|
std::string token;
|
|
while(std::getline(ss, token, ',')) {
|
|
if(!token.empty()) configs[currentLayer].sourceBranches.push_back(std::stoi(token));
|
|
}
|
|
}
|
|
}
|
|
|
|
file.close();
|
|
return !configs.empty();
|
|
}
|
|
|
|
} // namespace NodeEditor
|
|
|
|
// ============================================================================
|
|
// CHAT SYSTEM
|
|
// ============================================================================
|
|
|
|
/// One line of a chat: who said it and what was said.
struct ChatMessage {
    std::string role{};     // speaker tag
    std::string content{};  // message text
};

/// A named conversation holding its message history.
struct ChatSession {
    std::string name;
    std::vector<ChatMessage> messages;
    int id;

    ChatSession(int sessionId, const std::string& displayName)
        : name{displayName}, messages{}, id{sessionId} {}
};
|
|
|
|
// ============================================================================
|
|
// DATASET
|
|
// ============================================================================
|
|
|
|
/// One prompt/response pair extracted from the dataset text.
struct TrainingSample {
    std::string userText{};  // text between [USER] and [AI]
    std::string aiText{};    // text between [AI] and [EOS]
};
|
|
|
|
std::vector<TrainingSample> ParseDataset(const std::string& text, Tokenizer& tok) {
|
|
std::vector<TrainingSample> samples;
|
|
size_t pos = 0;
|
|
while(pos < text.length()) {
|
|
size_t userStart = text.find("[USER]", pos);
|
|
if(userStart == std::string::npos) break;
|
|
userStart += 6;
|
|
|
|
size_t aiStart = text.find("[AI]", userStart);
|
|
if(aiStart == std::string::npos) break;
|
|
aiStart += 4;
|
|
|
|
size_t eosPos = text.find("[EOS]", aiStart);
|
|
if(eosPos == std::string::npos) break;
|
|
|
|
std::string userText = text.substr(userStart, aiStart - userStart - 4);
|
|
std::string aiText = text.substr(aiStart, eosPos - aiStart);
|
|
|
|
while(!userText.empty() && (userText.back()==' ' || userText.back()=='\n' || userText.back()=='\r')) userText.pop_back();
|
|
while(!aiText.empty() && (aiText.back()==' ' || aiText.back()=='\n' || aiText.back()=='\r')) aiText.pop_back();
|
|
|
|
if(!userText.empty() && !aiText.empty()) {
|
|
samples.push_back({userText, aiText});
|
|
}
|
|
pos = eosPos + 5;
|
|
}
|
|
return samples;
|
|
}
|
|
|
|
/// Reads the whole file `filename` into `buffer`.
/// @return false if the file cannot be opened, in which case `buffer` is left
///         untouched; true otherwise.
bool LoadDatasetFromFile(std::string& buffer, const std::string& filename) {
    std::ifstream input(filename);
    const bool opened = input.is_open();
    if(opened) {
        std::ostringstream contents;
        contents << input.rdbuf();
        buffer = contents.str();
    }
    return opened;   // the stream closes itself on scope exit
}
|
|
|
|
// ============================================================================
|
|
// GLOBAL STATE
|
|
// ============================================================================
|
|
|
|
int UI_CONTEXT = 256, UI_EMBED = 8, UI_VOCAB = 300;
|
|
int MAX_RESPONSE_TOKENS = 50;
|
|
|
|
struct UIState {
|
|
std::vector<ChatSession> chats;
|
|
int activeChat = 0;
|
|
char inputBuf[512] = "";
|
|
bool scrollChat = true;
|
|
|
|
TrainStatus lastStatus;
|
|
std::mutex mtx;
|
|
float lr = 0.01f;
|
|
int epochs = 10;
|
|
int doneEpochs = 0;
|
|
|
|
double tokensPerSecond = 0.0;
|
|
double generationTokensPerSecond = 0.0;
|
|
int totalTokensTrained = 0;
|
|
std::chrono::steady_clock::time_point trainingStartTime;
|
|
std::chrono::steady_clock::time_point lastTokenTime;
|
|
|
|
char dsBuf[524288] = "";
|
|
std::vector<TrainingSample> parsedSamples;
|
|
bool datasetLoaded = false;
|
|
|
|
std::atomic<bool> training{false}, stop{false};
|
|
std::vector<LayerStructure_t> layers;
|
|
NodeEditor::GraphState graph;
|
|
|
|
std::string architectureText;
|
|
bool showArchitecture = true;
|
|
|
|
char archFilePath[256] = "architecture.txt";
|
|
char datasetFilePath[256] = "dataset.txt";
|
|
} ui;
|
|
|
|
// ============================================================================
|
|
// NEURAL NETWORK HELPERS
|
|
// ============================================================================
|
|
|
|
std::map<int, std::vector<double>> PrepInput(const std::vector<int>& toks, Embedder& emb) {
|
|
std::map<int, std::vector<double>> res;
|
|
for(size_t i=0;i<ui.layers.size();i++) {
|
|
if(ui.layers[i].sources.empty()) {
|
|
std::vector<double> d;
|
|
d.reserve(UI_CONTEXT*UI_EMBED);
|
|
int sz = (ui.layers[i].branch==-1) ? UI_CONTEXT : UI_CONTEXT/2;
|
|
int st = (ui.layers[i].branch==1) ? UI_CONTEXT/2 : 0;
|
|
int cnt=0;
|
|
|
|
for(int j=std::max(0,(int)toks.size()-UI_CONTEXT+st); j<(int)toks.size() && cnt<sz; j++) {
|
|
auto v=emb.get(toks[j]);
|
|
d.insert(d.end(),v.begin(),v.end());
|
|
cnt++;
|
|
}
|
|
while(cnt++<sz) for(int k=0;k<UI_EMBED;k++) d.push_back(0);
|
|
res[i]=d;
|
|
}
|
|
}
|
|
return res;
|
|
}
|
|
|
|
/// Background training loop (run on a detached thread). For every sample it
/// does teacher-forced next-token training: starting from the user tokens,
/// each AI token is predicted as a one-hot target over the output layer, then
/// appended to the context.
/// Shared state is written under ui.mtx. The dataset, epoch count and learning
/// rate are snapshotted at start, because the UI thread may reassign them while
/// training runs.
/// NOTE(review): ui.layers, UI_CONTEXT and `nn` can still be replaced by the UI
/// during training; guarding that belongs in the caller.
void TrainTask(NeuralNetwork* nn, Tokenizer* tk, Embedder* eb) {
    const std::vector<TrainingSample> samples = ui.parsedSamples;
    const int epochs = ui.epochs;
    const float lr = ui.lr;

    if(samples.empty() || ui.layers.empty() || epochs < 1) {
        ui.training=false;
        return;
    }

    const int os = ui.layers.back().size;   // output-layer width = one-hot target size
    const double totalSteps = (double)epochs * (double)samples.size();
    double totalLoss = 0;
    int sampleCount = 0;                    // number of trained tokens (loss terms)

    {
        std::lock_guard<std::mutex> lk(ui.mtx);
        ui.trainingStartTime = std::chrono::steady_clock::now();
        ui.totalTokensTrained = 0;
    }

    try {
        for(int e=0; e<epochs && !ui.stop; e++) {
            const auto epochStart = std::chrono::steady_clock::now();
            double epochLoss = 0;
            int epochTokens = 0;

            for(size_t s=0; s<samples.size() && !ui.stop; s++) {
                auto userToks = tk->textToTokens(samples[s].userText);
                auto aiToks = tk->textToTokens(samples[s].aiText);
                if(userToks.empty() || aiToks.empty()) continue;

                std::vector<int> context = userToks;
                for(size_t i=0; i<aiToks.size() && !ui.stop; i++) {
                    const auto tokenStart = std::chrono::steady_clock::now();

                    // One-hot target; ids outside [0, os) get an all-zero target
                    std::vector<double> target(os, 0);
                    const auto next = aiToks[i];
                    if(next >= 0 && next < os) target[next] = 1.0;

                    const double loss = nn->train(PrepInput(context, *eb), target, lr);
                    totalLoss += loss;
                    sampleCount++;
                    epochLoss += loss;
                    epochTokens++;

                    // Slide the context window forward by the ground-truth token
                    context.push_back(aiToks[i]);
                    if((int)context.size() > UI_CONTEXT) context.erase(context.begin());

                    // Instantaneous tokens per second
                    const double tokenTime = std::chrono::duration<double>(
                        std::chrono::steady_clock::now() - tokenStart).count();

                    std::lock_guard<std::mutex> lk(ui.mtx);
                    ui.totalTokensTrained++;
                    if(tokenTime > 0) ui.tokensPerSecond = 1.0 / tokenTime;
                    ui.lastStatus.loss = loss;
                    const double currentStep = (double)e * (double)samples.size() + (double)s;
                    ui.lastStatus.progress = (float)(currentStep / totalSteps * 100.0);
                    ui.lastStatus.epoch = e + 1;
                    ui.lastStatus.speed = ui.tokensPerSecond;
                }
            }

            const double epochTime = std::chrono::duration<double>(
                std::chrono::steady_clock::now() - epochStart).count();
            {
                std::lock_guard<std::mutex> lk(ui.mtx);
                ui.doneEpochs++;
            }

            // Per-epoch log; this thread is the only writer of tokensPerSecond, so reading it here is race-free.
            std::cout << "Epoch " << (e+1) << "/" << epochs
                      << " | Loss: " << (epochTokens > 0 ? epochLoss/epochTokens : 0.0)
                      << " | Time: " << epochTime << "s"
                      << " | Tokens/sec: " << ui.tokensPerSecond
                      << "\n";
        }

        std::lock_guard<std::mutex> lk(ui.mtx);
        if(sampleCount > 0) {
            ui.lastStatus.totalLoss = totalLoss / sampleCount;
            // Overall average throughput
            const double totalTime = std::chrono::duration<double>(
                std::chrono::steady_clock::now() - ui.trainingStartTime).count();
            if(totalTime > 0) ui.tokensPerSecond = ui.totalTokensTrained / totalTime;
        }
        if(!ui.stop) ui.lastStatus.progress = 100.0f;   // ran to completion
    } catch(const std::exception& ex) {
        // An escaping exception would call std::terminate on this detached thread.
        std::cerr << "Training aborted: " << ex.what() << "\n";
    } catch(...) {
        std::cerr << "Training aborted: unknown error\n";
    }

    ui.training=false;
    ui.stop=false;
}
|
|
// ============================================================================
|
|
// MAIN
|
|
// ============================================================================
|
|
|
|
int main() {
|
|
if(!glfwInit()) return 1;
|
|
GLFWwindow* win = glfwCreateWindow(1600,1000,"Xenith Studio",nullptr,nullptr);
|
|
glfwMakeContextCurrent(win); glfwSwapInterval(1);
|
|
|
|
IMGUI_CHECKVERSION(); ImGui::CreateContext();
|
|
ImGuiIO& io = ImGui::GetIO();
|
|
io.Fonts->AddFontFromFileTTF("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 16.0f, nullptr, io.Fonts->GetGlyphRangesCyrillic());
|
|
ImGui_ImplGlfw_InitForOpenGL(win, true); ImGui_ImplOpenGL3_Init("#version 130");
|
|
|
|
ui.layers = {
|
|
LayerStructure_t(UI_CONTEXT*UI_EMBED,{}),
|
|
LayerStructure_t(512,{0}),
|
|
LayerStructure_t(UI_VOCAB,{1})
|
|
};
|
|
ui.layers[0].branch = -1;
|
|
NodeEditor::SyncFromConfigs(ui.graph, ui.layers);
|
|
ui.chats.push_back(ChatSession(0, "Chat 1"));
|
|
|
|
Tokenizer tok;
|
|
Embedder emb(UI_VOCAB, UI_EMBED);
|
|
NeuralNetwork* nn = new NeuralNetwork(ui.layers.data(), ui.layers.size(), true);
|
|
|
|
// Try load architecture
|
|
if(NodeEditor::LoadArchitectureFromFile(ui.layers, ui.archFilePath)) {
|
|
NodeEditor::SyncFromConfigs(ui.graph, ui.layers);
|
|
delete nn;
|
|
nn = new NeuralNetwork(ui.layers.data(), ui.layers.size(), true);
|
|
std::cout << "Loaded architecture from " << ui.archFilePath << "\n";
|
|
}
|
|
|
|
while(!glfwWindowShouldClose(win)) {
|
|
glfwPollEvents();
|
|
ImGui_ImplOpenGL3_NewFrame(); ImGui_ImplGlfw_NewFrame(); ImGui::NewFrame();
|
|
|
|
ImGui::SetNextWindowPos(ImVec2(0,0)); ImGui::SetNextWindowSize(io.DisplaySize);
|
|
ImGui::Begin("Main", nullptr, ImGuiWindowFlags_NoDecoration|ImGuiWindowFlags_MenuBar|ImGuiWindowFlags_NoMove|ImGuiWindowFlags_NoResize);
|
|
|
|
if(ImGui::BeginMenuBar()) {
|
|
if(ImGui::BeginMenu("File")) {
|
|
if(ImGui::MenuItem("Apply & Save Arch")) {
|
|
NodeEditor::SyncToConfigs(ui.graph, ui.layers);
|
|
NodeEditor::SaveArchitectureToFile(ui.layers, ui.archFilePath);
|
|
delete nn;
|
|
nn = new NeuralNetwork(ui.layers.data(), ui.layers.size(), true);
|
|
ui.architectureText = NodeEditor::GenerateArchitectureText(ui.graph, ui.layers);
|
|
}
|
|
if(ImGui::MenuItem("Load Arch")) {
|
|
if(NodeEditor::LoadArchitectureFromFile(ui.layers, ui.archFilePath)) {
|
|
NodeEditor::SyncFromConfigs(ui.graph, ui.layers);
|
|
delete nn;
|
|
nn = new NeuralNetwork(ui.layers.data(), ui.layers.size(), true);
|
|
ui.architectureText = NodeEditor::GenerateArchitectureText(ui.graph, ui.layers);
|
|
}
|
|
}
|
|
ImGui::EndMenu();
|
|
}
|
|
ImGui::EndMenuBar();
|
|
}
|
|
|
|
{
|
|
std::lock_guard<std::mutex> lk(ui.mtx);
|
|
ImGui::Text("Epoch: %d/%d | Loss: %.5f | %.1f%%",
|
|
ui.lastStatus.epoch, ui.epochs, ui.lastStatus.loss, ui.lastStatus.progress);
|
|
ImGui::SameLine();
|
|
ImGui::TextColored(ui.training?ImVec4(0,1,0,1):ImVec4(1,1,0,1),
|
|
ui.training?"[TRAINING]":"[IDLE]");
|
|
ImGui::ProgressBar(ui.lastStatus.progress/100.0f, ImVec2(-1,18));
|
|
}
|
|
ImGui::Separator();
|
|
|
|
ImGui::Columns(3, "MainCols", true);
|
|
ImGui::SetColumnWidth(0, io.DisplaySize.x*0.35f);
|
|
ImGui::SetColumnWidth(1, io.DisplaySize.x*0.35f);
|
|
|
|
// === COLUMN 1: NODE EDITOR ===
|
|
ImGui::BeginChild("Graph", ImVec2(0,0), true);
|
|
ImGui::TextColored(ImVec4(0,1,1,1), "NODE GRAPH EDITOR");
|
|
ImGui::Separator();
|
|
NodeEditor::DrawGraph(ui.graph, ImGui::GetContentRegionAvail());
|
|
ImGui::EndChild();
|
|
|
|
// === COLUMN 2: ARCHITECTURE & SETTINGS ===
|
|
ImGui::NextColumn();
|
|
ImGui::BeginChild("Info", ImVec2(0,0), true);
|
|
|
|
if(ImGui::BeginTabBar("InfoTabs")) {
|
|
if(ImGui::BeginTabItem("Architecture")) {
|
|
if(ImGui::Button("Refresh")) {
|
|
ui.architectureText = NodeEditor::GenerateArchitectureText(ui.graph, ui.layers);
|
|
}
|
|
ImGui::SameLine();
|
|
ImGui::InputText("Arch File", ui.archFilePath, 256);
|
|
|
|
ImGui::Separator();
|
|
ImGui::BeginChild("ArchText", ImVec2(0,0), true);
|
|
ImGui::TextWrapped("%s", ui.architectureText.c_str());
|
|
ImGui::EndChild();
|
|
ImGui::EndTabItem();
|
|
}
|
|
|
|
// ---- "Training" tab: metrics, hyper-parameters, dataset loading, start/stop ----
if (ImGui::BeginTabItem("Training")) {
    ImGui::TextColored(ImVec4(0,1,1,1), "Performance Metrics:");
    ImGui::Separator();
    {
        // These counters are guarded by ui.mtx (presumably updated from the training thread).
        std::lock_guard<std::mutex> lk(ui.mtx);
        ImGui::Text("Training Speed: %.2f tokens/sec", ui.tokensPerSecond);
        ImGui::Text("Generation Speed: %.2f tokens/sec", ui.generationTokensPerSecond);
        ImGui::Text("Total Tokens Trained: %d", ui.totalTokensTrained);
    }

    ImGui::Separator();

    // Hyper-parameters; every integer is clamped to a minimum of 1.
    ImGui::SliderFloat("Learning Rate", &ui.lr, 0.0001f, 0.1f, "%.5f");
    ImGui::InputInt("Epochs", &ui.epochs);
    if (ui.epochs < 1) ui.epochs = 1;
    ImGui::InputInt("Max Response Tokens", &MAX_RESPONSE_TOKENS);
    if (MAX_RESPONSE_TOKENS < 1) MAX_RESPONSE_TOKENS = 1;
    ImGui::InputInt("Context Size", &UI_CONTEXT);
    if (UI_CONTEXT < 1) UI_CONTEXT = 1;
    ImGui::InputInt("Embedding Dim", &UI_EMBED);
    if (UI_EMBED < 1) UI_EMBED = 1;
    ImGui::InputInt("Vocab Size", &UI_VOCAB);
    if (UI_VOCAB < 1) UI_VOCAB = 1;

    ImGui::Separator();
    ImGui::TextColored(ImVec4(1,0.5f,0,1), "⚠️ Recommendations:");
    if (ui.parsedSamples.size() < 10) {
        ImGui::TextColored(ImVec4(1,0,0,1), "• Too few samples! Add at least 10-20 examples");
    }
    if (ui.epochs > 50 && ui.parsedSamples.size() < 5) {
        ImGui::TextColored(ImVec4(1,0,0,1), "• Too many epochs for small dataset! Reduce to 10-20");
    }
    if (ui.lr > 0.01f) {
        ImGui::TextColored(ImVec4(1,0.8f,0,1), "• High learning rate! Try 0.001-0.01");
    }

    if (ImGui::Button("Load Dataset from File")) {
        // Training is gated on parsedSamples (see START below), so the worker reads it;
        // replacing the vector underneath a running worker would be a data race.
        if (ui.training) {
            std::cerr << "Cannot reload dataset while training is running\n";
        } else {
            std::string fileText;
            if (LoadDatasetFromFile(fileText, ui.datasetFilePath)) {
                // Copy into the fixed-size ImGui editor buffer, always NUL-terminated.
                const size_t capacity = sizeof(ui.dsBuf) - 1;
                if (fileText.size() > capacity) {
                    std::cerr << "Warning: dataset file is " << fileText.size()
                              << " bytes; only the first " << capacity
                              << " bytes fit in the editor buffer and will be used\n";
                }
                strncpy(ui.dsBuf, fileText.c_str(), capacity);
                ui.dsBuf[capacity] = '\0';

                // Parse what the editor actually holds so the text and samples agree.
                ui.parsedSamples = ParseDataset(ui.dsBuf, tok);
                // Same rule as the Dataset tab: "loaded" means at least one sample parsed.
                ui.datasetLoaded = !ui.parsedSamples.empty();
                std::cout << "Loaded " << ui.parsedSamples.size() << " samples\n";
            } else {
                std::cerr << "Failed to load dataset\n";
            }
        }
    }
    ImGui::SameLine();
    if (ui.datasetLoaded) {
        // size_t is not portably printable with %lu (32-bit long on Windows).
        ImGui::TextColored(ImVec4(0,1,0,1), "✓ %d samples", static_cast<int>(ui.parsedSamples.size()));
    } else {
        ImGui::Text("Not loaded");
    }

    ImGui::Separator();

    if (!ui.training) {
        if (ui.parsedSamples.empty()) {
            // Visually dimmed placeholder; clicking it does nothing.
            ImGui::PushStyleVar(ImGuiStyleVar_Alpha, 0.5f);
            ImGui::Button("▶ START TRAINING (Load dataset first)", ImVec2(-1,45));
            ImGui::PopStyleVar();
        } else if (ImGui::Button("▶ START TRAINING", ImVec2(-1,45))) {
            ui.training = true;
            ui.stop = false;
            ui.doneEpochs = 0;
            // Detached worker; it is expected to clear ui.training when it finishes.
            std::thread(TrainTask, nn, &tok, &emb).detach();
        }
    } else {
        ImGui::PushStyleColor(ImGuiCol_Button, ImVec4(0.8f,0.2f,0.2f,1));
        if (ImGui::Button("⏹ STOP", ImVec2(-1,45))) ui.stop = true;
        ImGui::PopStyleColor();
    }

    ImGui::EndTabItem();
}
|
|
|
|
// ---- "Dataset" tab: raw dataset editor plus a manual re-parse button ----
if (ImGui::BeginTabItem("Dataset")) {
    if (ImGui::Button("Parse from Text")) {
        // Do not swap parsedSamples out from under a running training worker.
        if (ui.training) {
            std::cerr << "Cannot re-parse dataset while training is running\n";
        } else {
            ui.parsedSamples = ParseDataset(std::string(ui.dsBuf), tok);
            ui.datasetLoaded = !ui.parsedSamples.empty();
        }
    }
    ImGui::SameLine();
    // size_t is not portably printable with %lu; cast for %d.
    ImGui::Text("(%d samples)", static_cast<int>(ui.parsedSamples.size()));

    // A negative height means "extend to N pixels above the bottom edge". Reserve one
    // text line so the format hint below stays visible instead of being clipped.
    const float hintHeight = ImGui::GetTextLineHeightWithSpacing();
    ImGui::InputTextMultiline("##DS", ui.dsBuf, sizeof(ui.dsBuf), ImVec2(-1, -hintHeight),
                              ImGuiInputTextFlags_AllowTabInput);
    ImGui::TextWrapped("Format: [USER]text[AI]response[EOS]");

    ImGui::EndTabItem();
}
|
|
|
|
ImGui::EndTabBar();
|
|
}
|
|
ImGui::EndChild();
|
|
|
|
// === COLUMN 3: CHATS ===
|
|
ImGui::NextColumn();
|
|
ImGui::BeginChild("Chats", ImVec2(0,0), true);
|
|
|
|
static char newChatName[64] = "";
|
|
if(ImGui::Button("+ New Chat")) {
|
|
ui.chats.push_back(ChatSession(ui.chats.size(), std::string("Chat ") + std::to_string(ui.chats.size()+1)));
|
|
}
|
|
ImGui::SameLine();
|
|
ImGui::InputText("##NewChatName", newChatName, 64);
|
|
|
|
for(size_t i=0;i<ui.chats.size();i++) {
|
|
if(ImGui::Selectable(ui.chats[i].name.c_str(), (int)i==ui.activeChat)) {
|
|
ui.activeChat = i;
|
|
}
|
|
}
|
|
|
|
ImGui::Separator();
|
|
|
|
if(!ui.chats.empty() && ui.activeChat < ui.chats.size()) {
|
|
auto& chat = ui.chats[ui.activeChat];
|
|
ImGui::Text("%s", chat.name.c_str());
|
|
ImGui::Separator();
|
|
|
|
ImGui::BeginChild("Messages", ImVec2(0, -80), true);
|
|
for(auto& msg : chat.messages) {
|
|
if(msg.role == "USER") {
|
|
ImGui::PushStyleColor(ImGuiCol_Text, ImVec4(0.5f,0.8f,1.0f,1));
|
|
ImGui::Text("[You]: %s", msg.content.c_str());
|
|
ImGui::PopStyleColor();
|
|
} else {
|
|
ImGui::PushStyleColor(ImGuiCol_Text, ImVec4(0.8f,1.0f,0.5f,1));
|
|
ImGui::TextWrapped("[AI]: %s", msg.content.c_str());
|
|
ImGui::PopStyleColor();
|
|
}
|
|
}
|
|
if(ui.scrollChat) { ImGui::SetScrollHereY(1.0f); ui.scrollChat = false; }
|
|
ImGui::EndChild();
|
|
|
|
if(ImGui::InputText("##Input", ui.inputBuf, 512, ImGuiInputTextFlags_EnterReturnsTrue)) {
|
|
std::string input = ui.inputBuf;
|
|
chat.messages.push_back({"USER", input});
|
|
|
|
auto ctx = tok.textToTokens(input);
|
|
std::string ans;
|
|
int generatedTokens = 0;
|
|
|
|
auto genStartTime = std::chrono::steady_clock::now();
|
|
|
|
for(int g=0; g<MAX_RESPONSE_TOKENS; g++) {
|
|
auto tokenStart = std::chrono::steady_clock::now();
|
|
|
|
auto out = nn->feedForward(PrepInput(ctx, emb));
|
|
int id = std::max_element(out.begin(),out.end())-out.begin();
|
|
if(id<=0||id>=UI_VOCAB) break;
|
|
std::string w = tok.getWord(id);
|
|
if(w.empty()) break;
|
|
ans += w + " ";
|
|
ctx.push_back(id);
|
|
if((int)ctx.size()>UI_CONTEXT) ctx.erase(ctx.begin());
|
|
generatedTokens++;
|
|
|
|
auto tokenEnd = std::chrono::steady_clock::now();
|
|
double tokenTime = std::chrono::duration<double>(tokenEnd - tokenStart).count();
|
|
if(tokenTime > 0) {
|
|
ui.generationTokensPerSecond = 1.0 / tokenTime;
|
|
}
|
|
}
|
|
|
|
auto genEndTime = std::chrono::steady_clock::now();
|
|
double genTime = std::chrono::duration<double>(genEndTime - genStartTime).count();
|
|
|
|
// Добавляем информацию о скорости в лог (опционально)
|
|
std::cout << "Generated " << generatedTokens << " tokens in "
|
|
<< genTime << "s ("
|
|
<< (generatedTokens/genTime) << " tok/s)\n";
|
|
|
|
chat.messages.push_back({"AI", ans});
|
|
ui.inputBuf[0]=0;
|
|
ui.scrollChat=true;
|
|
}
|
|
|
|
}
|
|
|
|
ImGui::EndChild();
|
|
// Finish the main window and present the frame.
ImGui::Columns(1);
ImGui::End();

ImGui::Render();
// The viewport is in framebuffer pixels. io.DisplaySize is in window points, which
// differ on HiDPI/Retina displays, so query the framebuffer size directly.
int fbWidth = 0;
int fbHeight = 0;
glfwGetFramebufferSize(win, &fbWidth, &fbHeight);
glViewport(0, 0, fbWidth, fbHeight);
glClearColor(0.08f, 0.09f, 0.12f, 1);
glClear(GL_COLOR_BUFFER_BIT);
ImGui_ImplOpenGL3_RenderDrawData(ImGui::GetDrawData());
glfwSwapBuffers(win);
|
|
}
|
|
|
|
delete nn;
|
|
ImGui_ImplOpenGL3_Shutdown(); ImGui_ImplGlfw_Shutdown(); ImGui::DestroyContext();
|
|
glfwTerminate();
|
|
return 0;
|
|
} |