This commit is contained in:
2026-05-16 01:43:15 +07:00
parent ea467f8298
commit 91b05aae5d
8 changed files with 210 additions and 95 deletions
+173 -60
View File
@@ -1,5 +1,5 @@
// ============================================================================
// Xenith Studio - Node Editor v3.1 (Исправлено обучение + Загрузка из файла + Фикс портов)
// Xenith AI Studio v3.2
// ============================================================================
#define IMGUI_DEFINE_MATH_OPERATORS
@@ -21,12 +21,23 @@
#include <cmath>
#include <sstream>
#include <iomanip>
#include <chrono>
#include <cstring>
#include "Xenith/core.hpp"
#include "Xenith/token/token.hpp"
// ============================================================================
// NODE EDITOR ENGINE
// ⚠️ ГЛОБАЛЬНЫЕ НАСТРОЙКИ - ДОЛЖНЫ БЫТЬ ПЕРЕД ВСЕМИ ФУНКЦИЯМИ КОТОРЫЕ ИХ ИСПОЛЬЗУЮТ
// ============================================================================
int UI_CONTEXT = 512;
int UI_EMBED = 128;
int UI_VOCAB = 300;
int MAX_RESPONSE_TOKENS = 128;
// ============================================================================
// NODE EDITOR ENGINE (теперь видит UI_CONTEXT, UI_EMBED, UI_VOCAB)
// ============================================================================
namespace NodeEditor {
@@ -54,12 +65,14 @@ struct Node {
int layerSize;
int branchCount;
int activeBranch;
bool isSizeFixed; // NEW: для Input/Output
std::vector<Port> inputs;
std::vector<Port> outputs;
Node(int id_, const std::string& t_, NodeType type_)
: id(id_), title(t_), type(type_), pos(0,0), size(180,100),
selected(false), dragging(false), layerSize(128), branchCount(2), activeBranch(-1) {
selected(false), dragging(false), layerSize(128), branchCount(2),
activeBranch(-1), isSizeFixed(false) {
UpdatePorts();
}
@@ -67,25 +80,32 @@ struct Node {
inputs.clear(); outputs.clear();
switch(type) {
case NodeType::Input:
// ИСПРАВЛЕНО: 0 входов, 1 выход
outputs.push_back(Port("Out", PortType::Output));
isSizeFixed = true;
// Размер вычисляется из глобальных настроек
layerSize = UI_CONTEXT * UI_EMBED;
break;
case NodeType::Hidden:
// ИСПРАВЛЕНО: 1 вход, 1 выход
inputs.push_back(Port("In", PortType::Input));
outputs.push_back(Port("Out", PortType::Output));
isSizeFixed = false;
break;
case NodeType::Output:
// ИСПРАВЛЕНО: 1 вход, 0 выходов
inputs.push_back(Port("In", PortType::Input));
isSizeFixed = true;
layerSize = UI_VOCAB;
break;
case NodeType::Splitter:
inputs.push_back(Port("In", PortType::Input));
for(int i=0;i<branchCount;i++) outputs.push_back(Port("Br "+std::to_string(i), PortType::Output));
for(int i=0;i<branchCount;i++)
outputs.push_back(Port("Br "+std::to_string(i), PortType::Output));
isSizeFixed = false;
break;
case NodeType::Merger:
for(int i=0;i<branchCount;i++) inputs.push_back(Port("Br "+std::to_string(i), PortType::Input));
for(int i=0;i<branchCount;i++)
inputs.push_back(Port("Br "+std::to_string(i), PortType::Input));
outputs.push_back(Port("Out", PortType::Output));
isSizeFixed = false;
break;
}
}
@@ -138,12 +158,25 @@ void DrawNode(ImDrawList* dl, const Node& n, const GraphState& g, const ImVec2&
dl->AddRect(scr, scr + n.size, GetNodeColor(n.type, n.selected), 6, 0, 2.0f);
dl->AddText(scr + ImVec2(10,4), IM_COL32(255,255,255,255), n.title.c_str());
dl->AddText(scr + ImVec2(10,20), IM_COL32(160,160,160,200), std::to_string(n.layerSize).c_str());
// Отображение размера
if(n.isSizeFixed) {
std::string sizeInfo = std::to_string(n.layerSize) + " (auto)";
dl->AddText(scr + ImVec2(10,20), IM_COL32(160,160,160,200), sizeInfo.c_str());
if(n.type == NodeType::Input) {
dl->AddText(scr + ImVec2(10,36), IM_COL32(100,100,100,150), "CTX×EMB");
} else if(n.type == NodeType::Output) {
dl->AddText(scr + ImVec2(10,36), IM_COL32(100,100,100,150), "VOCAB");
}
} else {
dl->AddText(scr + ImVec2(10,20), IM_COL32(160,160,160,200), std::to_string(n.layerSize).c_str());
}
if(n.type == NodeType::Splitter || n.type == NodeType::Merger) {
dl->AddText(scr + ImVec2(10, n.size.y-18), IM_COL32(255,220,100,255), (" x"+std::to_string(n.branchCount)).c_str());
}
// Порты ввода
for(size_t i=0; i<n.inputs.size(); i++) {
ImVec2 p = n.GetPortScreenPos((int)i, true, canvasPos, pan);
bool hov = (g.hoveredPortNode == n.id && g.hoveredPortIdx == (int)i && g.hoveredPortType == PortType::Input);
@@ -151,6 +184,7 @@ void DrawNode(ImDrawList* dl, const Node& n, const GraphState& g, const ImVec2&
dl->AddCircle(p, 5, IM_COL32(50,50,50,255), 12, 1.5f);
dl->AddText(p + ImVec2(10,-5), IM_COL32(200,200,200,255), n.inputs[i].name.c_str());
}
// Порты вывода
for(size_t i=0; i<n.outputs.size(); i++) {
ImVec2 p = n.GetPortScreenPos((int)i, false, canvasPos, pan);
bool hov = (g.hoveredPortNode == n.id && g.hoveredPortIdx == (int)i && g.hoveredPortType == PortType::Output);
@@ -206,23 +240,32 @@ void HandleInput(GraphState& g, const ImVec2& canvasPos, const ImVec2& canvasSiz
for(const auto& n : g.nodes) {
for(size_t i=0;i<n.inputs.size();i++) {
ImVec2 p = n.GetPortScreenPos((int)i, true, canvasPos, g.pan);
if(ImLengthSqr(mouseScreen-p) < 64) { g.hoveredPortNode=n.id; g.hoveredPortIdx=(int)i; g.hoveredPortType=PortType::Input; }
if(ImLengthSqr(mouseScreen-p) < 64) {
g.hoveredPortNode=n.id; g.hoveredPortIdx=(int)i; g.hoveredPortType=PortType::Input;
}
}
for(size_t i=0;i<n.outputs.size();i++) {
ImVec2 p = n.GetPortScreenPos((int)i, false, canvasPos, g.pan);
if(ImLengthSqr(mouseScreen-p) < 64) { g.hoveredPortNode=n.id; g.hoveredPortIdx=(int)i; g.hoveredPortType=PortType::Output; }
if(ImLengthSqr(mouseScreen-p) < 64) {
g.hoveredPortNode=n.id; g.hoveredPortIdx=(int)i; g.hoveredPortType=PortType::Output;
}
}
}
if(ImGui::IsMouseClicked(ImGuiMouseButton_Left) && g.hoveredPortNode!=-1 && !g.panning) {
g.creatingConn=true; g.connStartNode=g.hoveredPortNode; g.connStartPort=g.hoveredPortIdx; g.connStartType=g.hoveredPortType;
g.creatingConn=true;
g.connStartNode=g.hoveredPortNode;
g.connStartPort=g.hoveredPortIdx;
g.connStartType=g.hoveredPortType;
}
if(g.creatingConn) {
g.connMousePos = mouseScreen;
if(ImGui::IsMouseReleased(ImGuiMouseButton_Left)) {
if(g.hoveredPortNode!=-1 && g.hoveredPortNode!=g.connStartNode && g.hoveredPortType!=g.connStartType) {
if(g.connStartType==PortType::Output) g.connections.emplace_back(g.connStartNode, g.connStartPort, g.hoveredPortNode, g.hoveredPortIdx);
else g.connections.emplace_back(g.hoveredPortNode, g.hoveredPortIdx, g.connStartNode, g.connStartPort);
if(g.connStartType==PortType::Output)
g.connections.emplace_back(g.connStartNode, g.connStartPort, g.hoveredPortNode, g.hoveredPortIdx);
else
g.connections.emplace_back(g.hoveredPortNode, g.hoveredPortIdx, g.connStartNode, g.connStartPort);
}
g.creatingConn=false;
}
@@ -243,7 +286,8 @@ void HandleInput(GraphState& g, const ImVec2& canvasPos, const ImVec2& canvasSiz
if(ImGui::IsMouseClicked(ImGuiMouseButton_Left) && g.hoveredPortNode==-1 && !g.creatingConn && !g.panning) {
bool onNode = false;
for(const auto& n : g.nodes) if(ImRect(canvasPos+n.pos+g.pan, canvasPos+n.pos+g.pan+n.size).Contains(mouseScreen)) onNode=true;
for(const auto& n : g.nodes)
if(ImRect(canvasPos+n.pos+g.pan, canvasPos+n.pos+g.pan+n.size).Contains(mouseScreen)) onNode=true;
if(!onNode) { g.selectedNode=-1; for(auto& n:g.nodes) n.selected=false; }
}
}
@@ -256,8 +300,10 @@ void DrawGraph(GraphState& g, const ImVec2& canvasSize) {
float gs = 40.0f;
ImVec2 off = ImVec2(fmodf(g.pan.x, gs), fmodf(g.pan.y, gs));
for(float x=off.x; x<canvasSize.x; x+=gs) dl->AddLine(canvasPos+ImVec2(x,0), canvasPos+ImVec2(x,canvasSize.y), IM_COL32(40,44,55,120));
for(float y=off.y; y<canvasSize.y; y+=gs) dl->AddLine(canvasPos+ImVec2(0,y), canvasPos+ImVec2(canvasSize.x,y), IM_COL32(40,44,55,120));
for(float x=off.x; x<canvasSize.x; x+=gs)
dl->AddLine(canvasPos+ImVec2(x,0), canvasPos+ImVec2(x,canvasSize.y), IM_COL32(40,44,55,120));
for(float y=off.y; y<canvasSize.y; y+=gs)
dl->AddLine(canvasPos+ImVec2(0,y), canvasPos+ImVec2(canvasSize.x,y), IM_COL32(40,44,55,120));
DrawConnections(dl, g, canvasPos, g.pan);
for(auto& n : g.nodes) DrawNode(dl, n, g, canvasPos, g.pan);
@@ -267,13 +313,29 @@ void DrawGraph(GraphState& g, const ImVec2& canvasSize) {
if(ImGui::BeginPopupContextItem("##Ctx")) {
ImVec2 wPos = ImGui::GetMousePos() - canvasPos - g.pan;
if(ImGui::MenuItem("🟢 Input")) { g.nodes.emplace_back(g.nextId++, "Input", NodeType::Input); g.nodes.back().pos=wPos; }
if(ImGui::MenuItem(" Hidden")) { g.nodes.emplace_back(g.nextId++, "Hidden", NodeType::Hidden); g.nodes.back().pos=wPos; }
if(ImGui::MenuItem(" Output")) { g.nodes.emplace_back(g.nextId++, "Output", NodeType::Output); g.nodes.back().pos=wPos; }
if(ImGui::MenuItem("🟢 Input")) {
auto& n = g.nodes.emplace_back(g.nextId++, "Input", NodeType::Input);
n.pos=wPos;
}
if(ImGui::MenuItem(" Hidden")) {
auto& n = g.nodes.emplace_back(g.nextId++, "Hidden", NodeType::Hidden);
n.pos=wPos;
}
if(ImGui::MenuItem(" Output")) {
auto& n = g.nodes.emplace_back(g.nextId++, "Output", NodeType::Output);
n.pos=wPos;
}
ImGui::Separator();
if(ImGui::MenuItem(" Splitter (x2)")) { auto& n=g.nodes.emplace_back(g.nextId++, "Splitter", NodeType::Splitter); n.pos=wPos; n.branchCount=2; n.UpdatePorts(); }
if(ImGui::MenuItem(" Merger (x2)")) { auto& n=g.nodes.emplace_back(g.nextId++, "Merger", NodeType::Merger); n.pos=wPos; n.branchCount=2; n.UpdatePorts(); }
ImGui::Separator(); ImGui::Text("LMB: Drag/Connect | RMB: Cancel | MMB: Pan");
if(ImGui::MenuItem(" Splitter")) {
auto& n = g.nodes.emplace_back(g.nextId++, "Splitter", NodeType::Splitter);
n.pos=wPos; n.branchCount=2; n.UpdatePorts();
}
if(ImGui::MenuItem(" Merger")) {
auto& n = g.nodes.emplace_back(g.nextId++, "Merger", NodeType::Merger);
n.pos=wPos; n.branchCount=2; n.UpdatePorts();
}
ImGui::Separator();
ImGui::Text("LMB: Drag/Connect | RMB: Cancel | MMB: Pan");
ImGui::EndPopup();
}
@@ -282,16 +344,32 @@ void DrawGraph(GraphState& g, const ImVec2& canvasSize) {
if(it != g.nodes.end()) {
Node& sel = *it;
ImGui::SetNextWindowPos(canvasPos + ImVec2(10,10));
ImGui::SetNextWindowSize(ImVec2(240,220));
ImGui::SetNextWindowSize(ImVec2(240,240));
if(ImGui::Begin("##Props", nullptr, ImGuiWindowFlags_NoTitleBar|ImGuiWindowFlags_AlwaysAutoResize|ImGuiWindowFlags_NoMove)) {
ImGui::TextColored(ImVec4(1,1,0.5f,1), "%s # %d", sel.title.c_str(), sel.id); ImGui::Separator();
ImGui::TextColored(ImVec4(1,1,0.5f,1), "%s # %d", sel.title.c_str(), sel.id);
ImGui::Separator();
if(sel.type != NodeType::Output) { ImGui::InputInt("Neurons", &sel.layerSize); if(sel.layerSize<1) sel.layerSize=1; }
// Редактирование размера ТОЛЬКО для не-фиксированных слоёв
if(!sel.isSizeFixed) {
ImGui::InputInt("Neurons", &sel.layerSize);
if(sel.layerSize < 1) sel.layerSize = 1;
} else {
ImGui::Text("Size: %d (auto)", sel.layerSize);
if(sel.type == NodeType::Input) {
ImGui::TextColored(ImVec4(0.5,0.5,0.5,1), "= CONTEXT × EMBED");
ImGui::TextColored(ImVec4(0.5,0.5,0.5,1), "= %d × %d", UI_CONTEXT, UI_EMBED);
} else if(sel.type == NodeType::Output) {
ImGui::TextColored(ImVec4(0.5,0.5,0.5,1), "= VOCAB_SIZE");
ImGui::TextColored(ImVec4(0.5,0.5,0.5,1), "= %d", UI_VOCAB);
}
}
if(sel.type == NodeType::Splitter || sel.type == NodeType::Merger) {
ImGui::Separator();
ImGui::SliderInt("Branches", &sel.branchCount, 2, 8);
if(ImGui::IsItemDeactivatedAfterEdit()) sel.UpdatePorts();
} else if(sel.type == NodeType::Input) {
ImGui::Separator();
ImGui::Text("Branch:");
ImGui::RadioButton("Combined", &sel.activeBranch, -1);
ImGui::RadioButton("A (0)", &sel.activeBranch, 0);
@@ -300,7 +378,8 @@ void DrawGraph(GraphState& g, const ImVec2& canvasSize) {
ImGui::Spacing();
if(ImGui::Button("🗑 Delete", ImVec2(-1,28))) {
g.connections.erase(std::remove_if(g.connections.begin(), g.connections.end(), [id=sel.id](const Connection& c){return c.fromNode==id||c.toNode==id;}), g.connections.end());
g.connections.erase(std::remove_if(g.connections.begin(), g.connections.end(),
[id=sel.id](const Connection& c){return c.fromNode==id||c.toNode==id;}), g.connections.end());
g.nodes.erase(it); g.selectedNode=-1;
}
ImGui::End();
@@ -316,8 +395,20 @@ void SyncFromConfigs(GraphState& g, const std::vector<LayerStructure_t>& cfgs) {
NodeType t = c.sources.empty() ? NodeType::Input : (i==cfgs.size()-1 ? NodeType::Output : NodeType::Hidden);
Node n(g.nextId++, t==NodeType::Input?"Input":t==NodeType::Output?"Output":"Hidden", t);
n.pos = ImVec2(120 + i*260, 150 + (i%3)*120);
n.layerSize = c.size; n.activeBranch = c.branch;
// UpdatePorts вызывается в конструкторе, но для гарантии:
// Фиксированные размеры для Input/Output
if(t == NodeType::Input) {
n.layerSize = UI_CONTEXT * UI_EMBED;
n.isSizeFixed = true;
} else if(t == NodeType::Output) {
n.layerSize = UI_VOCAB;
n.isSizeFixed = true;
} else {
n.layerSize = c.size;
n.isSizeFixed = false;
}
n.activeBranch = c.branch;
n.UpdatePorts();
g.nodes.push_back(n);
}
@@ -337,7 +428,9 @@ void SyncToConfigs(GraphState& g, std::vector<LayerStructure_t>& cfgs) {
for(size_t i=0;i<sorted.size();i++) idToIdx[sorted[i]->id] = (int)i;
for(const Node* n : sorted) {
LayerStructure_t l; l.size=n->layerSize; l.branch=n->activeBranch;
LayerStructure_t l;
l.size = n->layerSize;
l.branch = n->activeBranch;
for(const auto& c : g.connections) {
if(c.toNode==n->id && idToIdx.count(c.fromNode)) {
l.sources.push_back(idToIdx[c.fromNode]);
@@ -362,13 +455,21 @@ std::string GenerateArchitectureText(const GraphState& g, const std::vector<Laye
}
ss << "Total Parameters: " << totalParams << " (" << std::fixed << std::setprecision(2) << (totalParams/1000000.0) << "M)\n\n";
ss << "Layer Structure:\n";
ss << "----------------\n";
ss << "Layer Structure:\n----------------\n";
for(size_t i=0;i<configs.size();i++) {
auto& l = configs[i];
std::string type = l.sources.empty() ? "INPUT" : (i==configs.size()-1 ? "OUTPUT" : "HIDDEN");
ss << "Layer " << i << " [" << type << "]: " << l.size << " neurons";
if(type == "INPUT") {
ss << "Layer " << i << " [INPUT]: " << l.size << " neurons";
ss << " (CONTEXT=" << UI_CONTEXT << " × EMBED=" << UI_EMBED << ")";
} else if(type == "OUTPUT") {
ss << "Layer " << i << " [OUTPUT]: " << l.size << " neurons";
ss << " (VOCAB=" << UI_VOCAB << ")";
} else {
ss << "Layer " << i << " [HIDDEN]: " << l.size << " neurons";
}
if(l.branch != -1) ss << " (Branch " << (char)('A'+l.branch) << ")";
ss << "\n";
@@ -389,8 +490,9 @@ std::string GenerateArchitectureText(const GraphState& g, const std::vector<Laye
for(const auto& n : g.nodes) {
ss << "Node #" << n.id << " [" << n.title << "] @ ("
<< (int)n.pos.x << "," << (int)n.pos.y << ")\n";
ss << " Size: " << n.layerSize << " neurons\n";
ss << " Inputs: " << n.inputs.size() << " | Outputs: " << n.outputs.size() << "\n";
ss << " Size: " << n.layerSize << " neurons";
if(n.isSizeFixed) ss << " (auto)";
ss << "\n Inputs: " << n.inputs.size() << " | Outputs: " << n.outputs.size() << "\n";
if(n.type == NodeType::Splitter || n.type == NodeType::Merger)
ss << " Branches: " << n.branchCount << "\n";
}
@@ -398,7 +500,6 @@ std::string GenerateArchitectureText(const GraphState& g, const std::vector<Laye
return ss.str();
}
// === SAVE/LOAD ARCHITECTURE AS TEXT STRUCTURE ===
bool SaveArchitectureToFile(const std::vector<LayerStructure_t>& configs, const std::string& filename) {
std::ofstream file(filename);
if(!file.is_open()) return false;
@@ -538,9 +639,6 @@ bool LoadDatasetFromFile(std::string& buffer, const std::string& filename) {
// GLOBAL STATE
// ============================================================================
int UI_CONTEXT = 256, UI_EMBED = 8, UI_VOCAB = 300;
int MAX_RESPONSE_TOKENS = 50;
struct UIState {
std::vector<ChatSession> chats;
int activeChat = 0;
@@ -553,11 +651,11 @@ struct UIState {
int epochs = 10;
int doneEpochs = 0;
// Performance metrics
double tokensPerSecond = 0.0;
double generationTokensPerSecond = 0.0;
int totalTokensTrained = 0;
std::chrono::steady_clock::time_point trainingStartTime;
std::chrono::steady_clock::time_point lastTokenTime;
char dsBuf[524288] = "";
std::vector<TrainingSample> parsedSamples;
@@ -568,7 +666,6 @@ struct UIState {
NodeEditor::GraphState graph;
std::string architectureText;
bool showArchitecture = true;
char archFilePath[256] = "architecture.txt";
char datasetFilePath[256] = "dataset.txt";
@@ -609,7 +706,6 @@ void TrainTask(NeuralNetwork* nn, Tokenizer* tk, Embedder* eb) {
int os = ui.layers.back().size;
double totalLoss = 0;
int sampleCount = 0;
int totalTokens = 0;
ui.trainingStartTime = std::chrono::steady_clock::now();
ui.totalTokensTrained = 0;
@@ -633,7 +729,6 @@ void TrainTask(NeuralNetwork* nn, Tokenizer* tk, Embedder* eb) {
double loss = nn->train(PrepInput(context, *eb), target, ui.lr);
totalLoss += loss;
sampleCount++;
totalTokens++;
ui.totalTokensTrained++;
context.push_back(aiToks[i]);
@@ -663,19 +758,16 @@ void TrainTask(NeuralNetwork* nn, Tokenizer* tk, Embedder* eb) {
ui.doneEpochs++;
// Логируем каждую эпоху
std::cout << "Epoch " << (e+1) << "/" << ui.epochs
<< " | Loss: " << (totalLoss/sampleCount)
<< " | Loss: " << (totalLoss/(sampleCount>0?sampleCount:1))
<< " | Time: " << epochTime << "s"
<< " | Tokens/sec: " << ui.tokensPerSecond
<< "\n";
<< " | Tokens/sec: " << ui.tokensPerSecond << "\n";
}
if(sampleCount > 0) {
std::lock_guard<std::mutex> lk(ui.mtx);
ui.lastStatus.totalLoss = totalLoss / sampleCount;
// Общая статистика
auto totalTime = std::chrono::duration<double>(
std::chrono::steady_clock::now() - ui.trainingStartTime).count();
if(totalTime > 0) {
@@ -686,6 +778,7 @@ void TrainTask(NeuralNetwork* nn, Tokenizer* tk, Embedder* eb) {
ui.training=false;
ui.stop=false;
}
// ============================================================================
// MAIN
// ============================================================================
@@ -700,6 +793,7 @@ int main() {
io.Fonts->AddFontFromFileTTF("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 16.0f, nullptr, io.Fonts->GetGlyphRangesCyrillic());
ImGui_ImplGlfw_InitForOpenGL(win, true); ImGui_ImplOpenGL3_Init("#version 130");
// Init default architecture
ui.layers = {
LayerStructure_t(UI_CONTEXT*UI_EMBED,{}),
LayerStructure_t(512,{0}),
@@ -721,6 +815,9 @@ int main() {
std::cout << "Loaded architecture from " << ui.archFilePath << "\n";
}
// Track last values for auto-update
int lastContext = UI_CONTEXT, lastEmbed = UI_EMBED, lastVocab = UI_VOCAB;
while(!glfwWindowShouldClose(win)) {
glfwPollEvents();
ImGui_ImplOpenGL3_NewFrame(); ImGui_ImplGlfw_NewFrame(); ImGui::NewFrame();
@@ -750,6 +847,7 @@ int main() {
ImGui::EndMenuBar();
}
// Status bar
{
std::lock_guard<std::mutex> lk(ui.mtx);
ImGui::Text("Epoch: %d/%d | Loss: %.5f | %.1f%%",
@@ -761,6 +859,7 @@ int main() {
}
ImGui::Separator();
// Three columns
ImGui::Columns(3, "MainCols", true);
ImGui::SetColumnWidth(0, io.DisplaySize.x*0.35f);
ImGui::SetColumnWidth(1, io.DisplaySize.x*0.35f);
@@ -792,18 +891,18 @@ int main() {
}
if(ImGui::BeginTabItem("Training")) {
// Performance metrics
ImGui::TextColored(ImVec4(0,1,1,1), "Performance Metrics:");
ImGui::Separator();
{
std::lock_guard<std::mutex> lk(ui.mtx);
ImGui::Text("Training Speed: %.2f tokens/sec", ui.tokensPerSecond);
ImGui::Text("Generation Speed: %.2f tokens/sec", ui.generationTokensPerSecond);
ImGui::Text("Total Tokens Trained: %d", ui.totalTokensTrained);
}
ImGui::Separator();
// Settings
ImGui::SliderFloat("Learning Rate", &ui.lr, 0.0001f, 0.1f, "%.5f");
ImGui::InputInt("Epochs", &ui.epochs); if(ui.epochs<1) ui.epochs=1;
ImGui::InputInt("Max Response Tokens", &MAX_RESPONSE_TOKENS); if(MAX_RESPONSE_TOKENS<1) MAX_RESPONSE_TOKENS=1;
@@ -811,6 +910,21 @@ int main() {
ImGui::InputInt("Embedding Dim", &UI_EMBED); if(UI_EMBED<1) UI_EMBED=1;
ImGui::InputInt("Vocab Size", &UI_VOCAB); if(UI_VOCAB<1) UI_VOCAB=1;
// Auto-update fixed layer sizes when globals change
if(UI_CONTEXT != lastContext || UI_EMBED != lastEmbed || UI_VOCAB != lastVocab) {
lastContext = UI_CONTEXT; lastEmbed = UI_EMBED; lastVocab = UI_VOCAB;
for(auto& node : ui.graph.nodes) {
if(node.type == NodeEditor::NodeType::Input) {
node.layerSize = UI_CONTEXT * UI_EMBED;
node.UpdatePorts();
} else if(node.type == NodeEditor::NodeType::Output) {
node.layerSize = UI_VOCAB;
node.UpdatePorts();
}
}
ui.architectureText = NodeEditor::GenerateArchitectureText(ui.graph, ui.layers);
}
ImGui::Separator();
ImGui::TextColored(ImVec4(1,0.5f,0,1), "⚠️ Recommendations:");
if(ui.parsedSamples.size() < 10) {
@@ -822,15 +936,17 @@ int main() {
if(ui.lr > 0.01f) {
ImGui::TextColored(ImVec4(1,0.8f,0,1), "• High learning rate! Try 0.001-0.01");
}
// Dataset loading
ImGui::Separator();
ImGui::Text("Dataset File:");
ImGui::InputText("##DSFile", ui.datasetFilePath, 256);
if(ImGui::Button("Load Dataset from File")) {
std::string tempBuffer; // 1. Создаем переменную для загрузки
std::string tempBuffer;
if(LoadDatasetFromFile(tempBuffer, ui.datasetFilePath)) {
// 2. Копируем данные из переменной в буфер ImGui
// Используем strncpy, чтобы не переполнить массив ui.dsBuf
strncpy(ui.dsBuf, tempBuffer.c_str(), sizeof(ui.dsBuf) - 1);
ui.dsBuf[sizeof(ui.dsBuf) - 1] = '\0'; // Гарантируем завершение строки нулем
// 3. Парсим обновленный буфер
ui.dsBuf[sizeof(ui.dsBuf) - 1] = '\0';
ui.parsedSamples = ParseDataset(ui.dsBuf, tok);
ui.datasetLoaded = true;
std::cout << "Loaded " << ui.parsedSamples.size() << " samples\n";
@@ -954,16 +1070,13 @@ int main() {
auto genEndTime = std::chrono::steady_clock::now();
double genTime = std::chrono::duration<double>(genEndTime - genStartTime).count();
// Добавляем информацию о скорости в лог (опционально)
std::cout << "Generated " << generatedTokens << " tokens in "
<< genTime << "s ("
<< (generatedTokens/genTime) << " tok/s)\n";
<< genTime << "s (" << (generatedTokens/genTime) << " tok/s)\n";
chat.messages.push_back({"AI", ans});
ui.inputBuf[0]=0;
ui.scrollChat=true;
}
}
ImGui::EndChild();