V1.0
This commit is contained in:
+265
-177
@@ -1,34 +1,57 @@
|
||||
#include "core.hpp"
|
||||
#include "token/token.hpp"
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <chrono>
|
||||
#include <string.h>
|
||||
#include <omp.h>
|
||||
#include <fstream>
|
||||
#include <cstring>
|
||||
|
||||
int NeuralNetwork::calculateTotalInputSize(int layerIdx) const {
|
||||
int total = 0;
|
||||
for (size_t i = 0; i < layerSources[layerIdx].size(); i++) {
|
||||
int srcIdx = layerSources[layerIdx][i];
|
||||
total += sizes[srcIdx];
|
||||
}
|
||||
return total;
|
||||
}
|
||||
|
||||
NeuralNetwork::NeuralNetwork(LayerStructure_t layers[], int count, bool useVulkanParam) {
|
||||
this->numLayers = count;
|
||||
this->useVulkan = useVulkanParam;
|
||||
|
||||
uint32_t curW = 0, curB = 0, curO = 0;
|
||||
|
||||
// 1. Индексация нейронов и веток
|
||||
for (int i = 0; i < count; i++) {
|
||||
sizes.push_back(layers[i].size);
|
||||
layerSources.push_back(layers[i].sources);
|
||||
layerSourceBranches.push_back(layers[i].sourceBranches);
|
||||
branches.push_back(layers[i].branch);
|
||||
splits.push_back(layers[i].isSplit);
|
||||
oOff.push_back(curO);
|
||||
curO += layers[i].size;
|
||||
}
|
||||
|
||||
for (int i = 0; i < count - 1; i++) {
|
||||
// 2. Инициализация весов
|
||||
for (int i = 0; i < count; i++) {
|
||||
if (layerSources[i].empty()) {
|
||||
wOff.push_back(0);
|
||||
bOff.push_back(0);
|
||||
continue;
|
||||
}
|
||||
|
||||
wOff.push_back(curW);
|
||||
bOff.push_back(curB);
|
||||
int wCount = sizes[i] * sizes[i+1];
|
||||
float scale = sqrt(2.0f / sizes[i]);
|
||||
|
||||
int totalInSize = calculateTotalInputSize(i);
|
||||
int wCount = totalInSize * sizes[i];
|
||||
float scale = sqrt(2.0f / totalInSize);
|
||||
|
||||
for (int j = 0; j < wCount; j++) {
|
||||
h_weights.push_back(((float)rand() / RAND_MAX * 2.0f - 1.0f) * scale);
|
||||
}
|
||||
for (int j = 0; j < sizes[i+1]; j++) h_biases.push_back(0.0f);
|
||||
for (int j = 0; j < sizes[i]; j++) h_biases.push_back(0.0f);
|
||||
|
||||
curW += wCount;
|
||||
curB += sizes[i+1];
|
||||
curB += sizes[i];
|
||||
}
|
||||
|
||||
h_outputs.resize(curO, 0.0f);
|
||||
@@ -36,216 +59,281 @@ NeuralNetwork::NeuralNetwork(LayerStructure_t layers[], int count, bool useVulka
|
||||
|
||||
if (this->useVulkan) {
|
||||
initVulkan();
|
||||
if (this->useVulkan) {
|
||||
initVulkanResources();
|
||||
syncToGPU();
|
||||
}
|
||||
initVulkanResources();
|
||||
syncToGPU();
|
||||
}
|
||||
}
|
||||
|
||||
void NeuralNetwork::initVulkan() {
|
||||
try {
|
||||
vk::ApplicationInfo app{"Xenith", 1, nullptr, 0, VK_API_VERSION_1_1};
|
||||
instance = vk::createInstance({{}, &app});
|
||||
auto pdevs = instance.enumeratePhysicalDevices();
|
||||
if (pdevs.empty()) throw std::runtime_error("GPU not found");
|
||||
physDev = pdevs[0];
|
||||
double NeuralNetwork::train(const std::map<int, std::vector<double>>& inputs,
|
||||
const std::vector<double>& target, double lr) {
|
||||
if (!useVulkan) return 0.0;
|
||||
|
||||
auto props = physDev.getQueueFamilyProperties();
|
||||
computeQueueFamilyIndex = -1;
|
||||
for (uint32_t i = 0; i < props.size(); i++) {
|
||||
if (props[i].queueFlags & vk::QueueFlagBits::eCompute) {
|
||||
computeQueueFamilyIndex = i; break;
|
||||
}
|
||||
// Подготовка входных данных
|
||||
for (auto const& [layerIdx, data] : inputs) {
|
||||
if (layerIdx >= 0 && layerIdx < numLayers) {
|
||||
float* ptr = (float*)pO + oOff[layerIdx];
|
||||
size_t copySize = std::min(data.size(), (size_t)sizes[layerIdx]);
|
||||
for (size_t i = 0; i < copySize; i++) ptr[i] = (float)data[i];
|
||||
}
|
||||
vk::DeviceQueueCreateInfo qinfo({}, (uint32_t)computeQueueFamilyIndex, 1, new float{1.0f});
|
||||
device = physDev.createDevice({{}, 1, &qinfo});
|
||||
queue = device.getQueue(computeQueueFamilyIndex, 0);
|
||||
cmdPool = device.createCommandPool({{}, (uint32_t)computeQueueFamilyIndex});
|
||||
} catch (...) { useVulkan = false; }
|
||||
}
|
||||
|
||||
// Подготовка target
|
||||
float* fTar = (float*)pT;
|
||||
for (size_t i = 0; i < target.size(); i++) fTar[i] = (float)target[i];
|
||||
|
||||
// Получаем command buffer из пула
|
||||
if (cmdBuffers.empty()) {
|
||||
vk::CommandBufferAllocateInfo ai(cmdPool, vk::CommandBufferLevel::ePrimary, 4);
|
||||
cmdBuffers = device.allocateCommandBuffers(ai);
|
||||
}
|
||||
vk::CommandBuffer cmd = cmdBuffers[currentCmdBuffer];
|
||||
currentCmdBuffer = (currentCmdBuffer + 1) % cmdBuffers.size();
|
||||
|
||||
cmd.begin({vk::CommandBufferUsageFlagBits::eOneTimeSubmit});
|
||||
cmd.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline);
|
||||
cmd.bindDescriptorSets(vk::PipelineBindPoint::eCompute, pipeLayout, 0, {descriptorSet}, {});
|
||||
|
||||
vk::MemoryBarrier barrier(vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eShaderRead,
|
||||
vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eShaderRead);
|
||||
|
||||
// 1. FORWARD PASS
|
||||
for (int i = 0; i < numLayers; i++) {
|
||||
if (layerSources[i].empty()) continue;
|
||||
|
||||
int totalIn = calculateTotalInputSize(i);
|
||||
int firstSrc = layerSources[i][0];
|
||||
|
||||
TrainParams p = {0, (uint32_t)totalIn, (uint32_t)sizes[i], wOff[i], bOff[i],
|
||||
oOff[firstSrc], oOff[i], (float)lr};
|
||||
cmd.pushConstants(pipeLayout, vk::ShaderStageFlagBits::eCompute, 0, sizeof(TrainParams), &p);
|
||||
cmd.dispatch((sizes[i] + 255) / 256, 1, 1);
|
||||
cmd.pipelineBarrier(vk::PipelineStageFlagBits::eComputeShader,
|
||||
vk::PipelineStageFlagBits::eComputeShader, {}, {barrier}, {}, {});
|
||||
}
|
||||
|
||||
// 2. OUTPUT ERROR
|
||||
{
|
||||
TrainParams p = {1, 0, (uint32_t)sizes.back(), 0, 0, 0, oOff.back(), (float)lr};
|
||||
cmd.pushConstants(pipeLayout, vk::ShaderStageFlagBits::eCompute, 0, sizeof(TrainParams), &p);
|
||||
cmd.dispatch((sizes.back() + 255) / 256, 1, 1);
|
||||
cmd.pipelineBarrier(vk::PipelineStageFlagBits::eComputeShader,
|
||||
vk::PipelineStageFlagBits::eComputeShader, {}, {barrier}, {}, {});
|
||||
}
|
||||
|
||||
// 3. BACKWARD PASS
|
||||
for (int i = numLayers - 1; i >= 0; i--) {
|
||||
if (layerSources[i].empty()) continue;
|
||||
|
||||
int totalIn = calculateTotalInputSize(i);
|
||||
int firstSrc = layerSources[i][0];
|
||||
|
||||
TrainParams p = {2, (uint32_t)totalIn, (uint32_t)sizes[i], wOff[i], bOff[i],
|
||||
oOff[firstSrc], oOff[i], (float)lr};
|
||||
cmd.pushConstants(pipeLayout, vk::ShaderStageFlagBits::eCompute, 0, sizeof(TrainParams), &p);
|
||||
cmd.dispatch((totalIn + 255) / 256, 1, 1);
|
||||
cmd.pipelineBarrier(vk::PipelineStageFlagBits::eComputeShader,
|
||||
vk::PipelineStageFlagBits::eComputeShader, {}, {barrier}, {}, {});
|
||||
}
|
||||
|
||||
// 4. UPDATE WEIGHTS
|
||||
for (int i = 0; i < numLayers; i++) {
|
||||
if (layerSources[i].empty()) continue;
|
||||
|
||||
int totalIn = calculateTotalInputSize(i);
|
||||
int firstSrc = layerSources[i][0];
|
||||
|
||||
TrainParams p = {3, (uint32_t)totalIn, (uint32_t)sizes[i], wOff[i], bOff[i],
|
||||
oOff[firstSrc], oOff[i], (float)lr};
|
||||
cmd.pushConstants(pipeLayout, vk::ShaderStageFlagBits::eCompute, 0, sizeof(TrainParams), &p);
|
||||
cmd.dispatch((sizes[i] + 255) / 256, 1, 1);
|
||||
}
|
||||
|
||||
cmd.end();
|
||||
queue.submit(vk::SubmitInfo(0, nullptr, nullptr, 1, &cmd), nullptr);
|
||||
queue.waitIdle();
|
||||
|
||||
// Расчет MSE
|
||||
float* out = (float*)pO + oOff.back();
|
||||
double mse = 0;
|
||||
for (int i = 0; i < sizes.back(); i++) {
|
||||
double d = (double)target[i] - (double)out[i];
|
||||
mse += d * d;
|
||||
}
|
||||
return mse / sizes.back();
|
||||
}
|
||||
|
||||
std::vector<double> NeuralNetwork::feedForward(const std::map<int, std::vector<double>>& inputs) {
|
||||
if (!useVulkan) return {};
|
||||
|
||||
for (auto const& [layerIdx, data] : inputs) {
|
||||
if (layerIdx >= 0 && layerIdx < numLayers) {
|
||||
float* ptr = (float*)pO + oOff[layerIdx];
|
||||
size_t copySize = std::min(data.size(), (size_t)sizes[layerIdx]);
|
||||
memcpy(ptr, data.data(), copySize * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
if (cmdBuffers.empty()) {
|
||||
vk::CommandBufferAllocateInfo ai(cmdPool, vk::CommandBufferLevel::ePrimary, 4);
|
||||
cmdBuffers = device.allocateCommandBuffers(ai);
|
||||
}
|
||||
vk::CommandBuffer cmd = cmdBuffers[currentCmdBuffer];
|
||||
currentCmdBuffer = (currentCmdBuffer + 1) % cmdBuffers.size();
|
||||
|
||||
cmd.begin({vk::CommandBufferUsageFlagBits::eOneTimeSubmit});
|
||||
cmd.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline);
|
||||
cmd.bindDescriptorSets(vk::PipelineBindPoint::eCompute, pipeLayout, 0, {descriptorSet}, {});
|
||||
|
||||
vk::MemoryBarrier barrier(vk::AccessFlagBits::eShaderWrite, vk::AccessFlagBits::eShaderRead);
|
||||
|
||||
for (int i = 0; i < numLayers; i++) {
|
||||
if (layerSources[i].empty()) continue;
|
||||
|
||||
int totalIn = calculateTotalInputSize(i);
|
||||
int firstSrc = layerSources[i][0];
|
||||
|
||||
TrainParams p = {0, (uint32_t)totalIn, (uint32_t)sizes[i], wOff[i], bOff[i],
|
||||
oOff[firstSrc], oOff[i], 0.0f};
|
||||
cmd.pushConstants(pipeLayout, vk::ShaderStageFlagBits::eCompute, 0, sizeof(TrainParams), &p);
|
||||
cmd.dispatch((sizes[i] + 255) / 256, 1, 1);
|
||||
cmd.pipelineBarrier(vk::PipelineStageFlagBits::eComputeShader,
|
||||
vk::PipelineStageFlagBits::eComputeShader, {}, {barrier}, {}, {});
|
||||
}
|
||||
|
||||
cmd.end();
|
||||
queue.submit(vk::SubmitInfo(0, nullptr, nullptr, 1, &cmd), nullptr);
|
||||
queue.waitIdle();
|
||||
|
||||
std::vector<double> result(sizes.back());
|
||||
float* out = (float*)pO + oOff.back();
|
||||
for (int i = 0; i < sizes.back(); i++) result[i] = (double)out[i];
|
||||
return result;
|
||||
}
|
||||
|
||||
void NeuralNetwork::initVulkan() {
|
||||
vk::ApplicationInfo app{"Xenith", 1, nullptr, 0, VK_API_VERSION_1_1};
|
||||
instance = vk::createInstance({{}, &app});
|
||||
physDev = instance.enumeratePhysicalDevices()[0];
|
||||
auto props = physDev.getQueueFamilyProperties();
|
||||
int qIdx = -1;
|
||||
for (int i = 0; i < props.size(); i++)
|
||||
if (props[i].queueFlags & vk::QueueFlagBits::eCompute) {
|
||||
qIdx = i;
|
||||
break;
|
||||
}
|
||||
float priority = 1.0f;
|
||||
device = physDev.createDevice({{}, 1, new vk::DeviceQueueCreateInfo({}, (uint32_t)qIdx, 1, &priority)});
|
||||
queue = device.getQueue(qIdx, 0);
|
||||
cmdPool = device.createCommandPool({{}, (uint32_t)qIdx});
|
||||
}
|
||||
|
||||
void NeuralNetwork::initVulkanResources() {
|
||||
auto createBuf = [&](size_t sz, vk::Buffer& b, vk::DeviceMemory& m, void** ptr) {
|
||||
if (sz == 0) sz = 1;
|
||||
b = device.createBuffer({{}, sz * sizeof(float), vk::BufferUsageFlagBits::eStorageBuffer});
|
||||
auto req = device.getBufferMemoryRequirements(b);
|
||||
m = device.allocateMemory({req.size, findMemoryType(req.memoryTypeBits, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent)});
|
||||
m = device.allocateMemory({req.size, findMemoryType(req.memoryTypeBits,
|
||||
vk::MemoryPropertyFlagBits::eHostVisible |
|
||||
vk::MemoryPropertyFlagBits::eHostCoherent)});
|
||||
device.bindBufferMemory(b, m, 0);
|
||||
*ptr = device.mapMemory(m, 0, sz * sizeof(float));
|
||||
};
|
||||
|
||||
createBuf(h_weights.size(), gpuW, memW, &pW);
|
||||
createBuf(h_biases.size(), gpuB, memB, &pB);
|
||||
createBuf(h_outputs.size(), gpuO, memO, &pO);
|
||||
createBuf(h_errors.size(), gpuE, memE, &pE);
|
||||
createBuf(sizes.back(), gpuT, memT, &pT);
|
||||
|
||||
std::vector<vk::DescriptorSetLayoutBinding> binds = {
|
||||
vk::DescriptorSetLayoutBinding binds[] = {
|
||||
{0, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute},
|
||||
{1, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute},
|
||||
{2, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute},
|
||||
{3, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute},
|
||||
{4, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute}
|
||||
};
|
||||
dsLayout = device.createDescriptorSetLayout({{}, (uint32_t)binds.size(), binds.data()});
|
||||
dsLayout = device.createDescriptorSetLayout({{}, 5, binds});
|
||||
vk::DescriptorPoolSize ps(vk::DescriptorType::eStorageBuffer, 5);
|
||||
descriptorPool = device.createDescriptorPool({{}, 1, 1, &ps});
|
||||
descriptorSet = device.allocateDescriptorSets({descriptorPool, 1, &dsLayout})[0];
|
||||
vk::DescriptorBufferInfo bW(gpuW, 0, VK_WHOLE_SIZE), bB(gpuB, 0, VK_WHOLE_SIZE), bO(gpuO, 0, VK_WHOLE_SIZE), bE(gpuE, 0, VK_WHOLE_SIZE), bT(gpuT, 0, VK_WHOLE_SIZE);
|
||||
device.updateDescriptorSets({
|
||||
{descriptorSet, 0, 0, 1, vk::DescriptorType::eStorageBuffer, nullptr, &bW},
|
||||
{descriptorSet, 1, 0, 1, vk::DescriptorType::eStorageBuffer, nullptr, &bB},
|
||||
{descriptorSet, 2, 0, 1, vk::DescriptorType::eStorageBuffer, nullptr, &bO},
|
||||
{descriptorSet, 3, 0, 1, vk::DescriptorType::eStorageBuffer, nullptr, &bE},
|
||||
{descriptorSet, 4, 0, 1, vk::DescriptorType::eStorageBuffer, nullptr, &bT}
|
||||
}, {});
|
||||
|
||||
vk::DescriptorBufferInfo db[] = {
|
||||
{gpuW,0,VK_WHOLE_SIZE}, {gpuB,0,VK_WHOLE_SIZE}, {gpuO,0,VK_WHOLE_SIZE},
|
||||
{gpuE,0,VK_WHOLE_SIZE}, {gpuT,0,VK_WHOLE_SIZE}
|
||||
};
|
||||
vk::WriteDescriptorSet wds[] = {
|
||||
{descriptorSet,0,0,1,vk::DescriptorType::eStorageBuffer,nullptr,db+0},
|
||||
{descriptorSet,1,0,1,vk::DescriptorType::eStorageBuffer,nullptr,db+1},
|
||||
{descriptorSet,2,0,1,vk::DescriptorType::eStorageBuffer,nullptr,db+2},
|
||||
{descriptorSet,3,0,1,vk::DescriptorType::eStorageBuffer,nullptr,db+3},
|
||||
{descriptorSet,4,0,1,vk::DescriptorType::eStorageBuffer,nullptr,db+4}
|
||||
};
|
||||
device.updateDescriptorSets(5, wds, 0, nullptr);
|
||||
|
||||
auto code = readFile("Xenith/shader.comp.spv");
|
||||
if (code.empty()) {
|
||||
std::cerr << "ERROR: Failed to load shader.comp.spv!\n";
|
||||
exit(1);
|
||||
}
|
||||
shaderModule = device.createShaderModule({{}, code.size(), (uint32_t*)code.data()});
|
||||
vk::PushConstantRange pr(vk::ShaderStageFlagBits::eCompute, 0, sizeof(TrainParams));
|
||||
pipeLayout = device.createPipelineLayout({{}, 1, &dsLayout, 1, &pr});
|
||||
pipeline = device.createComputePipeline(nullptr, {{}, {{}, vk::ShaderStageFlagBits::eCompute, shaderModule, "main"}, pipeLayout}).value;
|
||||
pipeline = device.createComputePipeline(nullptr, {{}, {{}, vk::ShaderStageFlagBits::eCompute,
|
||||
shaderModule, "main"}, pipeLayout}).value;
|
||||
}
|
||||
|
||||
double NeuralNetwork::train(const std::vector<double>& input, const std::vector<double>& target, double lr) {
|
||||
if (!useVulkan) return runTrainCPU(input, target, lr);
|
||||
float* fIn = (float*)pO; for(size_t i=0; i<input.size(); i++) fIn[i] = (float)input[i];
|
||||
float* fTar = (float*)pT; for(size_t i=0; i<target.size(); i++) fTar[i] = (float)target[i];
|
||||
|
||||
vk::CommandBufferAllocateInfo ai(cmdPool, vk::CommandBufferLevel::ePrimary, 1);
|
||||
vk::CommandBuffer cmd = device.allocateCommandBuffers(ai)[0];
|
||||
cmd.begin({vk::CommandBufferUsageFlagBits::eOneTimeSubmit});
|
||||
cmd.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline);
|
||||
cmd.bindDescriptorSets(vk::PipelineBindPoint::eCompute, pipeLayout, 0, {descriptorSet}, {});
|
||||
|
||||
vk::MemoryBarrier barrier(vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eShaderRead, vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eShaderRead);
|
||||
for (int i = 0; i < numLayers - 1; i++) {
|
||||
TrainParams p = {0, (uint32_t)sizes[i], (uint32_t)sizes[i+1], wOff[i], bOff[i], oOff[i], oOff[i+1], (float)lr};
|
||||
cmd.pushConstants(pipeLayout, vk::ShaderStageFlagBits::eCompute, 0, sizeof(TrainParams), &p);
|
||||
cmd.dispatch((sizes[i+1] + 255) / 256, 1, 1);
|
||||
cmd.pipelineBarrier(vk::PipelineStageFlagBits::eComputeShader, vk::PipelineStageFlagBits::eComputeShader, {}, {barrier}, {}, {});
|
||||
}
|
||||
{
|
||||
TrainParams p = {1, 0, (uint32_t)sizes.back(), 0, 0, 0, oOff.back(), (float)lr};
|
||||
cmd.pushConstants(pipeLayout, vk::ShaderStageFlagBits::eCompute, 0, sizeof(TrainParams), &p);
|
||||
cmd.dispatch((sizes.back() + 255) / 256, 1, 1);
|
||||
cmd.pipelineBarrier(vk::PipelineStageFlagBits::eComputeShader, vk::PipelineStageFlagBits::eComputeShader, {}, {barrier}, {}, {});
|
||||
}
|
||||
for (int i = numLayers - 2; i > 0; i--) {
|
||||
TrainParams p = {2, (uint32_t)sizes[i], (uint32_t)sizes[i+1], wOff[i], bOff[i], oOff[i], oOff[i+1], (float)lr};
|
||||
cmd.pushConstants(pipeLayout, vk::ShaderStageFlagBits::eCompute, 0, sizeof(TrainParams), &p);
|
||||
cmd.dispatch((sizes[i] + 255) / 256, 1, 1);
|
||||
cmd.pipelineBarrier(vk::PipelineStageFlagBits::eComputeShader, vk::PipelineStageFlagBits::eComputeShader, {}, {barrier}, {}, {});
|
||||
}
|
||||
for (int i = 0; i < numLayers - 1; i++) {
|
||||
TrainParams p = {3, (uint32_t)sizes[i], (uint32_t)sizes[i+1], wOff[i], bOff[i], oOff[i], oOff[i+1], (float)lr};
|
||||
cmd.pushConstants(pipeLayout, vk::ShaderStageFlagBits::eCompute, 0, sizeof(TrainParams), &p);
|
||||
cmd.dispatch((sizes[i+1] + 255) / 256, 1, 1);
|
||||
}
|
||||
cmd.end();
|
||||
queue.submit(vk::SubmitInfo(0, nullptr, nullptr, 1, &cmd), nullptr);
|
||||
queue.waitIdle();
|
||||
device.freeCommandBuffers(cmdPool, cmd);
|
||||
|
||||
double mse = 0;
|
||||
float* out = (float*)pO + oOff.back();
|
||||
for (int i = 0; i < sizes.back(); i++) { double d = (double)target[i] - (double)out[i]; mse += d * d; }
|
||||
return mse / sizes.back();
|
||||
}
|
||||
|
||||
void NeuralNetwork::trainOnSequence(Tokenizer& tok, Embedder& emb, const std::string& dataset,
|
||||
int epochs, double lr,
|
||||
std::function<std::vector<double>(const std::vector<int>&, Embedder&)> buildInput,
|
||||
std::function<void(const TrainStatus&)> onProgress) {
|
||||
std::vector<int> tokens = tok.textToTokens(dataset);
|
||||
if (tokens.size() < 2) return;
|
||||
|
||||
int clrId = -1;
|
||||
auto search = tok.textToTokens("[CLR]"); if(!search.empty()) clrId = search[0];
|
||||
|
||||
auto startTime = std::chrono::high_resolution_clock::now();
|
||||
long long totalSteps = (long long)epochs * (tokens.size() - 1);
|
||||
long long currentGlobalStep = 0;
|
||||
double lastEpochLoss = 0;
|
||||
|
||||
long long totalParamsCount = 0;
|
||||
for (int i = 0; i < numLayers - 1; i++) {
|
||||
totalParamsCount += (long long)sizes[i] * sizes[i+1]; // веса
|
||||
totalParamsCount += (long long)sizes[i+1]; // смещения
|
||||
}
|
||||
|
||||
for (int e = 1; e <= epochs; e++) {
|
||||
double currentEpochLoss = 0;
|
||||
std::vector<int> slidingContext;
|
||||
for (size_t i = 0; i < tokens.size(); i++) {
|
||||
if (tokens[i] == clrId) { slidingContext.clear(); continue; }
|
||||
if (!slidingContext.empty()) {
|
||||
std::vector<double> target(MAX_VOCAB, 0.0);
|
||||
target[tokens[i]] = 1.0;
|
||||
double loss = this->train(buildInput(slidingContext, emb), target, lr);
|
||||
currentEpochLoss += loss;
|
||||
currentGlobalStep++;
|
||||
if (onProgress && currentGlobalStep % 10 == 0) {
|
||||
auto now = std::chrono::high_resolution_clock::now();
|
||||
double elapsed = std::chrono::duration<double>(now - startTime).count();
|
||||
double speed = currentGlobalStep / elapsed;
|
||||
TrainStatus status = {e, epochs, (int)i, (int)tokens.size(), loss, currentEpochLoss, lastEpochLoss, speed, (totalSteps - currentGlobalStep) / speed, (float)currentGlobalStep / totalSteps * 100.0f, totalParamsCount};
|
||||
onProgress(status);
|
||||
}
|
||||
}
|
||||
slidingContext.push_back(tokens[i]);
|
||||
if (slidingContext.size() > MAX_CONTEXT) slidingContext.erase(slidingContext.begin());
|
||||
}
|
||||
lastEpochLoss = currentEpochLoss;
|
||||
}
|
||||
if (useVulkan) syncToCPU();
|
||||
}
|
||||
|
||||
std::vector<double> NeuralNetwork::feedForward(const std::vector<double>& input) {
|
||||
if (useVulkan) {
|
||||
float* fIn = (float*)pO; for(size_t i=0; i<input.size(); i++) fIn[i] = (float)input[i];
|
||||
vk::CommandBufferAllocateInfo ai(cmdPool, vk::CommandBufferLevel::ePrimary, 1);
|
||||
vk::CommandBuffer cmd = device.allocateCommandBuffers(ai)[0];
|
||||
cmd.begin({vk::CommandBufferUsageFlagBits::eOneTimeSubmit});
|
||||
cmd.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline);
|
||||
cmd.bindDescriptorSets(vk::PipelineBindPoint::eCompute, pipeLayout, 0, {descriptorSet}, {});
|
||||
vk::MemoryBarrier b(vk::AccessFlagBits::eShaderWrite, vk::AccessFlagBits::eShaderRead);
|
||||
for (int i = 0; i < numLayers - 1; i++) {
|
||||
TrainParams p = {0, (uint32_t)sizes[i], (uint32_t)sizes[i+1], wOff[i], bOff[i], oOff[i], oOff[i+1], 0.0f};
|
||||
cmd.pushConstants(pipeLayout, vk::ShaderStageFlagBits::eCompute, 0, sizeof(TrainParams), &p);
|
||||
cmd.dispatch((sizes[i+1] + 255) / 256, 1, 1);
|
||||
cmd.pipelineBarrier(vk::PipelineStageFlagBits::eComputeShader, vk::PipelineStageFlagBits::eComputeShader, {}, {b}, {}, {});
|
||||
}
|
||||
cmd.end();
|
||||
queue.submit(vk::SubmitInfo(0, nullptr, nullptr, 1, &cmd), nullptr);
|
||||
queue.waitIdle();
|
||||
device.freeCommandBuffers(cmdPool, cmd);
|
||||
std::vector<double> res(sizes.back());
|
||||
float* out = (float*)pO + oOff.back();
|
||||
for(int i=0; i<sizes.back(); i++) res[i] = (double)out[i];
|
||||
return res;
|
||||
}
|
||||
return std::vector<double>(sizes.back(), 0.0);
|
||||
}
|
||||
|
||||
void NeuralNetwork::syncToCPU() { if(useVulkan) { memcpy(h_weights.data(), pW, h_weights.size()*4); memcpy(h_biases.data(), pB, h_biases.size()*4); } }
|
||||
void NeuralNetwork::syncToGPU() { if(useVulkan) { memcpy(pW, h_weights.data(), h_weights.size()*4); memcpy(pB, h_biases.data(), h_biases.size()*4); } }
|
||||
uint32_t NeuralNetwork::findMemoryType(uint32_t f, vk::MemoryPropertyFlags p) {
|
||||
auto m = physDev.getMemoryProperties();
|
||||
for(uint32_t i=0; i<m.memoryTypeCount; i++) if((f&(1<<i)) && (m.memoryTypes[i].propertyFlags&p)==p) return i;
|
||||
auto props = physDev.getMemoryProperties();
|
||||
for (uint32_t i = 0; i < props.memoryTypeCount; i++)
|
||||
if ((f & (1 << i)) && (props.memoryTypes[i].propertyFlags & p) == p) return i;
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::vector<char> NeuralNetwork::readFile(const std::string& n) {
|
||||
std::ifstream f(n, std::ios::ate|std::ios::binary);
|
||||
size_t s = (size_t)f.tellg(); std::vector<char> b(s); f.seekg(0); f.read(b.data(), s); return b;
|
||||
std::ifstream f(n, std::ios::ate | std::ios::binary);
|
||||
if (!f.is_open()) {
|
||||
std::cerr << "ERROR: Cannot open file: " << n << "\n";
|
||||
return {};
|
||||
}
|
||||
size_t s = (size_t)f.tellg();
|
||||
std::vector<char> b(s);
|
||||
f.seekg(0);
|
||||
f.read(b.data(), s);
|
||||
return b;
|
||||
}
|
||||
|
||||
void NeuralNetwork::syncToCPU() {
|
||||
memcpy(h_weights.data(), pW, h_weights.size() * 4);
|
||||
memcpy(h_biases.data(), pB, h_biases.size() * 4);
|
||||
}
|
||||
|
||||
void NeuralNetwork::syncToGPU() {
|
||||
memcpy(pW, h_weights.data(), h_weights.size() * 4);
|
||||
memcpy(pB, h_biases.data(), h_biases.size() * 4);
|
||||
}
|
||||
|
||||
NeuralNetwork::~NeuralNetwork() {
|
||||
if (useVulkan) {
|
||||
device.waitIdle();
|
||||
device.destroyPipeline(pipeline); device.destroyPipelineLayout(pipeLayout); device.destroyShaderModule(shaderModule);
|
||||
device.destroyBuffer(gpuW); device.freeMemory(memW); device.destroyBuffer(gpuB); device.freeMemory(memB);
|
||||
device.destroyBuffer(gpuO); device.freeMemory(memO); device.destroyBuffer(gpuE); device.freeMemory(memE); device.destroyBuffer(gpuT); device.freeMemory(memT);
|
||||
device.destroyDescriptorPool(descriptorPool); device.destroyDescriptorSetLayout(dsLayout); device.destroyCommandPool(cmdPool);
|
||||
device.destroy(); instance.destroy();
|
||||
|
||||
if (!cmdBuffers.empty()) {
|
||||
device.freeCommandBuffers(cmdPool, cmdBuffers);
|
||||
}
|
||||
|
||||
device.destroyPipeline(pipeline);
|
||||
device.destroyPipelineLayout(pipeLayout);
|
||||
device.destroyShaderModule(shaderModule);
|
||||
|
||||
device.destroyBuffer(gpuW); device.freeMemory(memW);
|
||||
device.destroyBuffer(gpuB); device.freeMemory(memB);
|
||||
device.destroyBuffer(gpuO); device.freeMemory(memO);
|
||||
device.destroyBuffer(gpuE); device.freeMemory(memE);
|
||||
device.destroyBuffer(gpuT); device.freeMemory(memT);
|
||||
|
||||
device.destroyDescriptorPool(descriptorPool);
|
||||
device.destroyDescriptorSetLayout(dsLayout);
|
||||
device.destroyCommandPool(cmdPool);
|
||||
device.destroy();
|
||||
instance.destroy();
|
||||
}
|
||||
}
|
||||
double NeuralNetwork::runTrainCPU(const std::vector<double>& i, const std::vector<double>& t, double l) { return 0.0; }
|
||||
}
|
||||
+93
-72
@@ -1,101 +1,122 @@
|
||||
#ifndef CORE_H
|
||||
#define CORE_H
|
||||
#ifndef CORE_HPP
|
||||
#define CORE_HPP
|
||||
|
||||
#include "typedef.hpp"
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <vulkan/vulkan.hpp>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
|
||||
struct TrainStatus {
|
||||
int currentEpoch;
|
||||
int totalEpochs;
|
||||
int currentToken;
|
||||
int totalTokens;
|
||||
double currentLoss;
|
||||
double epochLoss;
|
||||
double lastEpochLoss;
|
||||
double speed;
|
||||
double eta;
|
||||
float percentage;
|
||||
long totalParams;
|
||||
// Структура для описания слоя
|
||||
struct LayerStructure_t {
|
||||
int size; // Количество нейронов в слое
|
||||
std::vector<int> sources; // Индексы слоев, которые являются входными для данного
|
||||
std::vector<int> sourceBranches; // Ветка для каждого источника (-1 если не разделен, 0 или 1)
|
||||
int branch; // -1 если не разделен, 0 или 1 если слой работает с конкретной веткой
|
||||
bool isSplit; // true если слой разделяет выход на две ветки
|
||||
|
||||
LayerStructure_t() : size(0), branch(-1), isSplit(false) {}
|
||||
LayerStructure_t(int s, const std::vector<int>& src = {})
|
||||
: size(s), sources(src), branch(-1), isSplit(false) {
|
||||
sourceBranches.resize(src.size(), -1);
|
||||
}
|
||||
};
|
||||
|
||||
class Tokenizer;
|
||||
class Embedder;
|
||||
// Параметры для передачи в шейдер через Push Constants
|
||||
struct TrainParams {
|
||||
uint32_t mode; // 0: FF, 1: OutError, 2: BackProp, 3: Update
|
||||
uint32_t prevSize; // Суммарный размер входов
|
||||
uint32_t nextSize; // Размер текущего слоя
|
||||
uint32_t wOff; // Смещение весов в буфере W
|
||||
uint32_t bOff; // Смещение смещений в буфере B
|
||||
uint32_t oOff; // Смещение входов в буфере O
|
||||
uint32_t nextOOff; // Смещение выходов в буфере O
|
||||
float lr; // Скорость обучения
|
||||
};
|
||||
|
||||
// Структура для обратной связи о процессе обучения
|
||||
struct TrainStatus {
|
||||
int epoch = 0;
|
||||
int totalEpochs = 0;
|
||||
int step = 0;
|
||||
int totalSteps = 0;
|
||||
double loss = 0.0;
|
||||
double totalLoss = 0.0;
|
||||
double lastEpochLoss = 0.0;
|
||||
double speed = 0.0;
|
||||
double eta = 0.0;
|
||||
float progress = 0.0f;
|
||||
long long params = 0;
|
||||
};
|
||||
|
||||
class NeuralNetwork {
|
||||
private:
|
||||
int numLayers;
|
||||
std::vector<int> sizes;
|
||||
std::vector<float> h_weights, h_biases, h_outputs, h_errors;
|
||||
std::vector<uint32_t> wOff, bOff, oOff;
|
||||
public:
|
||||
NeuralNetwork(LayerStructure_t layers[], int count, bool useVulkanParam);
|
||||
~NeuralNetwork();
|
||||
|
||||
double train(const std::map<int, std::vector<double>>& inputs,
|
||||
const std::vector<double>& target, double lr);
|
||||
std::vector<double> feedForward(const std::map<int, std::vector<double>>& inputs);
|
||||
|
||||
void syncToCPU();
|
||||
void syncToGPU();
|
||||
|
||||
// Геттеры для UI
|
||||
const std::vector<int>& getSizes() const { return sizes; }
|
||||
const std::vector<std::vector<int>>& getLayerSources() const { return layerSources; }
|
||||
const std::vector<std::vector<int>>& getSourceBranches() const { return layerSourceBranches; }
|
||||
const std::vector<int>& getBranches() const { return branches; }
|
||||
const std::vector<bool>& getSplits() const { return splits; }
|
||||
|
||||
private:
|
||||
bool useVulkan;
|
||||
int numLayers;
|
||||
|
||||
// Метаданные архитектуры
|
||||
std::vector<int> sizes;
|
||||
std::vector<uint32_t> wOff;
|
||||
std::vector<uint32_t> bOff;
|
||||
std::vector<uint32_t> oOff;
|
||||
std::vector<std::vector<int>> layerSources;
|
||||
std::vector<std::vector<int>> layerSourceBranches; // Ветка для каждого источника
|
||||
std::vector<int> branches; // Ветка каждого слоя (-1, 0, 1)
|
||||
std::vector<bool> splits; // Является ли слой split
|
||||
|
||||
// Данные на стороне CPU
|
||||
std::vector<float> h_weights;
|
||||
std::vector<float> h_biases;
|
||||
std::vector<float> h_outputs;
|
||||
std::vector<float> h_errors;
|
||||
|
||||
// Vulkan объекты
|
||||
vk::Instance instance;
|
||||
vk::PhysicalDevice physDev;
|
||||
vk::Device device;
|
||||
vk::Queue queue;
|
||||
vk::CommandPool cmdPool;
|
||||
uint32_t computeQueueFamilyIndex;
|
||||
|
||||
vk::Buffer gpuW, gpuB, gpuO, gpuE, gpuT;
|
||||
vk::DeviceMemory memW, memB, memO, memE, memT;
|
||||
void *pW = nullptr, *pB = nullptr, *pO = nullptr, *pE = nullptr, *pT = nullptr;
|
||||
|
||||
|
||||
vk::DescriptorSetLayout dsLayout;
|
||||
vk::DescriptorPool descriptorPool;
|
||||
vk::DescriptorSet descriptorSet;
|
||||
vk::DescriptorSetLayout dsLayout;
|
||||
vk::PipelineLayout pipeLayout;
|
||||
vk::Pipeline pipeline;
|
||||
vk::ShaderModule shaderModule;
|
||||
|
||||
struct TrainParams {
|
||||
uint32_t mode;
|
||||
uint32_t prevSize;
|
||||
uint32_t nextSize;
|
||||
uint32_t wOff;
|
||||
uint32_t bOff;
|
||||
uint32_t oOff;
|
||||
uint32_t nextOOff;
|
||||
float lr;
|
||||
};
|
||||
vk::Buffer gpuW, gpuB, gpuO, gpuE, gpuT;
|
||||
vk::DeviceMemory memW, memB, memO, memE, memT;
|
||||
|
||||
void *pW = nullptr, *pB = nullptr, *pO = nullptr, *pE = nullptr, *pT = nullptr;
|
||||
|
||||
// Command buffer pooling
|
||||
std::vector<vk::CommandBuffer> cmdBuffers;
|
||||
uint32_t currentCmdBuffer = 0;
|
||||
|
||||
void initVulkan();
|
||||
void initVulkanResources();
|
||||
uint32_t findMemoryType(uint32_t typeFilter, vk::MemoryPropertyFlags properties);
|
||||
uint32_t findMemoryType(uint32_t f, vk::MemoryPropertyFlags p);
|
||||
std::vector<char> readFile(const std::string& filename);
|
||||
double runTrainCPU(const std::vector<double>& input, const std::vector<double>& target, double lr);
|
||||
|
||||
public:
|
||||
int cpu_count = 4;
|
||||
NeuralNetwork(LayerStructure_t layers[], int count, bool useVulkan = false);
|
||||
~NeuralNetwork();
|
||||
|
||||
void syncToCPU();
|
||||
void syncToGPU();
|
||||
|
||||
std::vector<double> feedForward(const std::vector<double>& input);
|
||||
double train(const std::vector<double>& input, const std::vector<double>& target, double lr);
|
||||
|
||||
void trainOnSequence(
|
||||
Tokenizer& tok,
|
||||
Embedder& emb,
|
||||
const std::string& dataset,
|
||||
int epochs,
|
||||
double lr,
|
||||
std::function<std::vector<double>(const std::vector<int>&, Embedder&)> buildInput,
|
||||
std::function<void(const TrainStatus&)> onProgress = nullptr
|
||||
);
|
||||
|
||||
long long getTotalParameters() {
|
||||
long long total = 0;
|
||||
for (int i = 0; i < numLayers - 1; i++) {
|
||||
total += (long long)sizes[i] * sizes[i+1];
|
||||
total += (long long)sizes[i+1];
|
||||
}
|
||||
return total;
|
||||
}
|
||||
int calculateTotalInputSize(int layerIdx) const;
|
||||
};
|
||||
|
||||
#endif
|
||||
#endif // CORE_HPP
|
||||
@@ -0,0 +1,513 @@
|
||||
#include "node_editor.hpp"
|
||||
#include "core.hpp"
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
|
||||
namespace NodeEditor {
|
||||
|
||||
// === Вспомогательные функции ===
|
||||
|
||||
ImVec2 GetPortPos(const Node& node, const Port& port, const ImVec2& canvasOffset) {
|
||||
float portY = node.pos.y + 30 + port.index * 25;
|
||||
if (port.type == PortType::Input) {
|
||||
return ImVec2(node.pos.x + canvasOffset.x, portY + canvasOffset.y);
|
||||
} else {
|
||||
return ImVec2(node.pos.x + node.size.x + canvasOffset.x, portY + canvasOffset.y);
|
||||
}
|
||||
}
|
||||
|
||||
void DrawBezier(ImDrawList* dl, ImVec2 start, ImVec2 end, ImU32 color, float thickness) {
|
||||
ImVec2 ctrl1 = start + ImVec2(50, 0);
|
||||
ImVec2 ctrl2 = end - ImVec2(50, 0);
|
||||
dl->AddBezierCubic(start, ctrl1, ctrl2, end, color, thickness, 32);
|
||||
|
||||
// Стрелочка на конце
|
||||
float angle = atan2(end.y - ctrl2.y, end.x - ctrl2.x);
|
||||
ImVec2 arrow1 = end + ImVec2(cos(angle - 0.5f) * 8, sin(angle - 0.5f) * 8);
|
||||
ImVec2 arrow2 = end + ImVec2(cos(angle + 0.5f) * 8, sin(angle + 0.5f) * 8);
|
||||
dl->AddLine(end, arrow1, color, thickness);
|
||||
dl->AddLine(end, arrow2, color, thickness);
|
||||
}
|
||||
|
||||
ImU32 GetNodeColor(NodeType type, bool selected) {
|
||||
if (selected) return IM_COL32(255, 255, 100, 255);
|
||||
switch(type) {
|
||||
case NodeType::Input: return IM_COL32(100, 200, 100, 255);
|
||||
case NodeType::Hidden: return IM_COL32(100, 150, 255, 255);
|
||||
case NodeType::Output: return IM_COL32(255, 100, 100, 255);
|
||||
default: return IM_COL32(150, 150, 150, 255);
|
||||
}
|
||||
}
|
||||
|
||||
// === Методы Node ===
|
||||
|
||||
ImVec2 Node::GetInputPos(int portIdx) const {
|
||||
return ImVec2(pos.x, pos.y + 35 + portIdx * 25);
|
||||
}
|
||||
|
||||
ImVec2 Node::GetOutputPos(int portIdx) const {
|
||||
return ImVec2(pos.x + size.x, pos.y + 35 + portIdx * 25);
|
||||
}
|
||||
|
||||
// === Инициализация ===
|
||||
|
||||
void Init(GraphState& graph) {
|
||||
graph.nextNodeId = 0;
|
||||
graph.zoom = 1.0f;
|
||||
graph.panOffset = ImVec2(100, 50);
|
||||
}
|
||||
|
||||
// === Отрисовка узла ===
|
||||
|
||||
void DrawNode(GraphState& graph, Node& node, const ImVec2& canvasOffset) {
|
||||
ImDrawList* dl = ImGui::GetWindowDrawList();
|
||||
ImVec2 pos = node.pos + canvasOffset;
|
||||
|
||||
// Фон узла
|
||||
ImU32 bgColor = GetNodeColor(node.type, node.selected);
|
||||
dl->AddRectFilled(pos, pos + node.size, IM_COL32(40, 45, 55, 255), 8);
|
||||
dl->AddRect(pos, pos + node.size, bgColor, 8, 0, 2.0f);
|
||||
|
||||
// Заголовок
|
||||
ImVec2 titlePos = pos + ImVec2(10, 5);
|
||||
dl->AddText(titlePos, IM_COL32(255, 255, 255, 255), node.title.c_str());
|
||||
|
||||
// Размер слоя
|
||||
std::string sizeText = std::to_string(node.layerSize) + " нейронов";
|
||||
dl->AddText(pos + ImVec2(10, 22), IM_COL32(180, 180, 180, 200), sizeText.c_str());
|
||||
|
||||
// Ветка (если есть)
|
||||
if (node.branch != -1) {
|
||||
const char* branchName = node.branch == 0 ? "A" : "B";
|
||||
ImU32 branchColor = node.branch == 0 ? IM_COL32(100, 255, 100, 255) : IM_COL32(100, 150, 255, 255);
|
||||
dl->AddRectFilled(pos + ImVec2(node.size.x - 35, 3),
|
||||
pos + ImVec2(node.size.x - 5, 18),
|
||||
branchColor, 3);
|
||||
dl->AddText(pos + ImVec2(node.size.x - 28, 5), IM_COL32(0,0,0,255), branchName);
|
||||
}
|
||||
|
||||
// Порты ввода
|
||||
for (size_t i = 0; i < node.inputs.size(); i++) {
|
||||
Port& port = node.inputs[i];
|
||||
ImVec2 portPos = GetPortPos(node, port, canvasOffset);
|
||||
|
||||
// Кружок порта
|
||||
ImU32 portColor = IM_COL32(180, 180, 180, 255);
|
||||
if (graph.hoveredPortNode == node.id && graph.hoveredPortIdx == (int)i &&
|
||||
graph.hoveredPortType == PortType::Input) {
|
||||
portColor = IM_COL32(255, 255, 100, 255);
|
||||
}
|
||||
dl->AddCircleFilled(portPos, 6, portColor);
|
||||
dl->AddCircle(portPos, 6, IM_COL32(50, 50, 50, 255), 12, 1.5f);
|
||||
|
||||
// Название порта
|
||||
dl->AddText(portPos + ImVec2(12, -6), IM_COL32(200, 200, 200, 255), port.name.c_str());
|
||||
|
||||
// Выбор ветки для порта
|
||||
if (port.isBranchPort) {
|
||||
ImVec2 btnPos = portPos + ImVec2(100, 0);
|
||||
if (ImGui::InvisibleButton(("##branch" + std::to_string(node.id) + "_" + std::to_string(i)).c_str(), ImVec2(50, 15))) {
|
||||
node.branch = (node.branch + 1) % 3 - 1; // -1 -> 0 -> 1 -> -1
|
||||
}
|
||||
const char* branchTxt = node.branch == -1 ? "All" : (node.branch == 0 ? "A" : "B");
|
||||
dl->AddText(btnPos, IM_COL32(255, 255, 255, 200), branchTxt);
|
||||
}
|
||||
}
|
||||
|
||||
// Порты вывода
|
||||
for (size_t i = 0; i < node.outputs.size(); i++) {
|
||||
Port& port = node.outputs[i];
|
||||
ImVec2 portPos = GetPortPos(node, port, canvasOffset);
|
||||
|
||||
ImU32 portColor = IM_COL32(180, 180, 180, 255);
|
||||
if (graph.hoveredPortNode == node.id && graph.hoveredPortIdx == (int)i &&
|
||||
graph.hoveredPortType == PortType::Output) {
|
||||
portColor = IM_COL32(255, 255, 100, 255);
|
||||
}
|
||||
dl->AddCircleFilled(portPos, 6, portColor);
|
||||
dl->AddCircle(portPos, 6, IM_COL32(50, 50, 50, 255), 12, 1.5f);
|
||||
|
||||
// Название справа от порта
|
||||
ImVec2 textPos = portPos - ImVec2(12 + ImGui::CalcTextSize(port.name.c_str()).x, 6);
|
||||
dl->AddText(textPos, IM_COL32(200, 200, 200, 255), port.name.c_str());
|
||||
}
|
||||
|
||||
// Индикатор разделения
|
||||
if (node.isSplit) {
|
||||
dl->AddText(pos + ImVec2(10, node.size.y - 20), IM_COL32(255, 200, 100, 255), "✦ Split Output");
|
||||
}
|
||||
}
|
||||
|
||||
// === Отрисовка соединений ===
|
||||
|
||||
void DrawConnections(GraphState& graph, const ImVec2& canvasOffset) {
|
||||
ImDrawList* dl = ImGui::GetWindowDrawList();
|
||||
|
||||
for (const auto& conn : graph.connections) {
|
||||
auto fromNodeIt = std::find_if(graph.nodes.begin(), graph.nodes.end(),
|
||||
[conn](const Node& n) { return n.id == conn.fromNode; });
|
||||
auto toNodeIt = std::find_if(graph.nodes.begin(), graph.nodes.end(),
|
||||
[conn](const Node& n) { return n.id == conn.toNode; });
|
||||
|
||||
if (fromNodeIt == graph.nodes.end() || toNodeIt == graph.nodes.end()) continue;
|
||||
|
||||
const Node& from = *fromNodeIt;
|
||||
const Node& to = *toNodeIt;
|
||||
|
||||
ImVec2 start = GetPortPos(from, from.outputs[conn.fromPort], canvasOffset);
|
||||
ImVec2 end = GetPortPos(to, to.inputs[conn.toPort], canvasOffset);
|
||||
|
||||
// Цвет в зависимости от ветки
|
||||
ImU32 color = IM_COL32(200, 200, 200, 180);
|
||||
if (to.branch == 0) color = IM_COL32(100, 255, 100, 180);
|
||||
else if (to.branch == 1) color = IM_COL32(100, 150, 255, 180);
|
||||
|
||||
DrawBezier(dl, start, end, color);
|
||||
}
|
||||
|
||||
// Линия при создании соединения
|
||||
if (graph.creatingConnection) {
|
||||
auto nodeIt = std::find_if(graph.nodes.begin(), graph.nodes.end(),
|
||||
[graph](const Node& n) { return n.id == graph.connectionStartNode; });
|
||||
if (nodeIt != graph.nodes.end()) {
|
||||
const Node& startNode = *nodeIt;
|
||||
ImVec2 start = graph.connectionStartType == PortType::Output
|
||||
? startNode.GetOutputPos(graph.connectionStartPort)
|
||||
: startNode.GetInputPos(graph.connectionStartPort);
|
||||
start += canvasOffset;
|
||||
|
||||
DrawBezier(dl, start, graph.connectionMousePos + canvasOffset,
|
||||
IM_COL32(255, 255, 100, 200), 2.5f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === Обработка ввода ===
|
||||
|
||||
void HandleInput(GraphState& graph, const ImVec2& canvasPos) {
|
||||
ImGuiIO& io = ImGui::GetIO();
|
||||
ImVec2 mousePos = io.MousePos - canvasPos - graph.panOffset;
|
||||
|
||||
// Панорамирование (средняя кнопка мыши)
|
||||
if (ImGui::IsMouseDown(2)) {
|
||||
if (!graph.panning) {
|
||||
graph.panning = true;
|
||||
graph.panStart = io.MousePos - graph.panOffset;
|
||||
}
|
||||
graph.panOffset = io.MousePos - graph.panStart;
|
||||
} else {
|
||||
graph.panning = false;
|
||||
}
|
||||
|
||||
// Масштаб колесом
|
||||
if (ImGui::IsWindowHovered() && io.MouseWheel != 0) {
|
||||
float zoomDelta = io.MouseWheel * 0.1f;
|
||||
graph.zoom = ImClamp(graph.zoom + zoomDelta, 0.5f, 3.0f);
|
||||
}
|
||||
|
||||
// Проверка ховера над портами
|
||||
graph.hoveredPortNode = -1;
|
||||
for (auto& node : graph.nodes) {
|
||||
for (size_t i = 0; i < node.inputs.size(); i++) {
|
||||
ImVec2 pos = GetPortPos(node, node.inputs[i], graph.panOffset);
|
||||
if (ImLengthSqr(mousePos - pos) < 100) {
|
||||
graph.hoveredPortNode = node.id;
|
||||
graph.hoveredPortIdx = (int)i;
|
||||
graph.hoveredPortType = PortType::Input;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < node.outputs.size(); i++) {
|
||||
ImVec2 pos = GetPortPos(node, node.outputs[i], graph.panOffset);
|
||||
if (ImLengthSqr(mousePos - pos) < 100) {
|
||||
graph.hoveredPortNode = node.id;
|
||||
graph.hoveredPortIdx = (int)i;
|
||||
graph.hoveredPortType = PortType::Output;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Создание соединения
|
||||
if (ImGui::IsMouseClicked(0) && graph.hoveredPortNode != -1) {
|
||||
graph.creatingConnection = true;
|
||||
graph.connectionStartNode = graph.hoveredPortNode;
|
||||
graph.connectionStartPort = graph.hoveredPortIdx;
|
||||
graph.connectionStartType = graph.hoveredPortType;
|
||||
}
|
||||
|
||||
if (graph.creatingConnection) {
|
||||
graph.connectionMousePos = mousePos;
|
||||
|
||||
if (ImGui::IsMouseReleased(0)) {
|
||||
// Завершение соединения
|
||||
if (graph.hoveredPortNode != -1 &&
|
||||
graph.hoveredPortNode != graph.connectionStartNode &&
|
||||
graph.hoveredPortType != graph.connectionStartType) {
|
||||
|
||||
// Добавляем соединение (Output -> Input)
|
||||
if (graph.connectionStartType == PortType::Output) {
|
||||
graph.connections.emplace_back(
|
||||
graph.connectionStartNode, graph.connectionStartPort,
|
||||
graph.hoveredPortNode, graph.hoveredPortIdx);
|
||||
} else {
|
||||
graph.connections.emplace_back(
|
||||
graph.hoveredPortNode, graph.hoveredPortIdx,
|
||||
graph.connectionStartNode, graph.connectionStartPort);
|
||||
}
|
||||
}
|
||||
graph.creatingConnection = false;
|
||||
}
|
||||
|
||||
// Отмена правой кнопкой
|
||||
if (ImGui::IsMouseClicked(1)) {
|
||||
graph.creatingConnection = false;
|
||||
}
|
||||
}
|
||||
|
||||
// Перетаскивание узлов
|
||||
for (auto& node : graph.nodes) {
|
||||
ImVec2 nodeScreenPos = node.pos + graph.panOffset;
|
||||
if (ImRect(nodeScreenPos, nodeScreenPos + node.size).Contains(io.MousePos - graph.panOffset)) {
|
||||
if (ImGui::IsMouseClicked(0) && !graph.creatingConnection) {
|
||||
node.selected = true;
|
||||
node.dragging = true;
|
||||
graph.selectedNode = node.id;
|
||||
node.dragOffset = io.MousePos - node.pos - graph.panOffset;
|
||||
}
|
||||
}
|
||||
|
||||
if (node.dragging) {
|
||||
node.pos = io.MousePos - node.dragOffset - graph.panOffset;
|
||||
if (ImGui::IsMouseReleased(0)) {
|
||||
node.dragging = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Выбор узла по клику на пустом месте
|
||||
if (ImGui::IsMouseClicked(0) && graph.hoveredPortNode == -1 && !graph.creatingConnection) {
|
||||
bool clickedOnNode = false;
|
||||
for (const auto& node : graph.nodes) {
|
||||
if (ImRect(node.pos + graph.panOffset, node.pos + graph.panOffset + node.size)
|
||||
.Contains(io.MousePos)) {
|
||||
clickedOnNode = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!clickedOnNode) {
|
||||
graph.selectedNode = -1;
|
||||
for (auto& node : graph.nodes) node.selected = false;
|
||||
}
|
||||
}
|
||||
|
||||
// Удаление соединения по правому клику
|
||||
if (ImGui::IsMouseClicked(1) && !graph.creatingConnection) {
|
||||
for (auto it = graph.connections.begin(); it != graph.connections.end(); ) {
|
||||
// Проверка ховера над линией (упрощенная)
|
||||
auto fromIt = std::find_if(graph.nodes.begin(), graph.nodes.end(),
|
||||
[it](const Node& n) { return n.id == it->fromNode; });
|
||||
auto toIt = std::find_if(graph.nodes.begin(), graph.nodes.end(),
|
||||
[it](const Node& n) { return n.id == it->toNode; });
|
||||
|
||||
if (fromIt != graph.nodes.end() && toIt != graph.nodes.end()) {
|
||||
ImVec2 start = GetPortPos(*fromIt, fromIt->outputs[it->fromPort], graph.panOffset);
|
||||
ImVec2 end = GetPortPos(*toIt, toIt->inputs[it->toPort], graph.panOffset);
|
||||
|
||||
// Простая проверка расстояния до линии
|
||||
float dist = ImLineClosestPoint(start, end, mousePos + graph.panOffset);
|
||||
if (ImLengthSqr((mousePos + graph.panOffset) - dist) < 25) {
|
||||
it = graph.connections.erase(it);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === Основная отрисовка ===
|
||||
|
||||
void DrawGraph(GraphState& graph, const ImVec2& canvasSize) {
|
||||
ImDrawList* dl = ImGui::GetWindowDrawList();
|
||||
ImVec2 canvasPos = ImGui::GetCursorScreenPos();
|
||||
|
||||
// Фон с сеткой
|
||||
dl->AddRectFilled(canvasPos, canvasPos + canvasSize, IM_COL32(30, 32, 40, 255));
|
||||
|
||||
// Сетка
|
||||
float gridSize = 50 * graph.zoom;
|
||||
for (float x = fmodf(-graph.panOffset.x, gridSize); x < canvasSize.x; x += gridSize) {
|
||||
dl->AddLine(canvasPos + ImVec2(x, 0), canvasPos + ImVec2(x, canvasSize.y),
|
||||
IM_COL32(50, 55, 70, 100));
|
||||
}
|
||||
for (float y = fmodf(-graph.panOffset.y, gridSize); y < canvasSize.y; y += gridSize) {
|
||||
dl->AddLine(canvasPos + ImVec2(0, y), canvasPos + ImVec2(canvasSize.x, y),
|
||||
IM_COL32(50, 55, 70, 100));
|
||||
}
|
||||
|
||||
// Соединения (сначала, чтобы были под узлами)
|
||||
DrawConnections(graph, graph.panOffset);
|
||||
|
||||
// Узлы
|
||||
for (auto& node : graph.nodes) {
|
||||
DrawNode(graph, node, graph.panOffset);
|
||||
}
|
||||
|
||||
// Обработка ввода
|
||||
ImGui::InvisibleButton("##GraphCanvas", canvasSize);
|
||||
if (ImGui::IsItemHovered()) {
|
||||
HandleInput(graph, canvasPos);
|
||||
}
|
||||
|
||||
// Контекстное меню для добавления узлов
|
||||
if (ImGui::BeginPopupContextItem("##GraphContext")) {
|
||||
if (ImGui::MenuItem("➕ Входной слой")) {
|
||||
Node newNode(graph.nextNodeId++, "Input", NodeType::Input);
|
||||
newNode.pos = ImGui::GetMousePos() - canvasPos - graph.panOffset;
|
||||
newNode.size = ImVec2(180, 90);
|
||||
newNode.inputs = {};
|
||||
newNode.outputs = {Port("Output", PortType::Output)};
|
||||
newNode.layerSize = 256 * 8; // CONTEXT * EMBED
|
||||
graph.nodes.push_back(newNode);
|
||||
}
|
||||
if (ImGui::MenuItem("⬜ Скрытый слой")) {
|
||||
Node newNode(graph.nextNodeId++, "Hidden", NodeType::Hidden);
|
||||
newNode.pos = ImGui::GetMousePos() - canvasPos - graph.panOffset;
|
||||
newNode.size = ImVec2(180, 100);
|
||||
newNode.inputs = {Port("Input", PortType::Input)};
|
||||
newNode.outputs = {Port("Output", PortType::Output)};
|
||||
graph.nodes.push_back(newNode);
|
||||
}
|
||||
if (ImGui::MenuItem("🔴 Выходной слой")) {
|
||||
Node newNode(graph.nextNodeId++, "Output", NodeType::Output);
|
||||
newNode.pos = ImGui::GetMousePos() - canvasPos - graph.panOffset;
|
||||
newNode.size = ImVec2(180, 90);
|
||||
newNode.inputs = {Port("Input", PortType::Input)};
|
||||
newNode.outputs = {};
|
||||
newNode.layerSize = 300; // VOCAB
|
||||
graph.nodes.push_back(newNode);
|
||||
}
|
||||
ImGui::Separator();
|
||||
ImGui::Text("Управление:");
|
||||
ImGui::Text("• ЛКМ: перетащить узел / создать связь");
|
||||
ImGui::Text("• ПКМ: удалить связь / отмена");
|
||||
ImGui::Text("• Колесо: масштаб");
|
||||
ImGui::Text("• Средняя кнопка: панорамирование");
|
||||
ImGui::EndPopup();
|
||||
}
|
||||
|
||||
// Панель свойств выбранного узла
|
||||
if (graph.selectedNode != -1) {
|
||||
auto selIt = std::find_if(graph.nodes.begin(), graph.nodes.end(),
|
||||
[graph](const Node& n) { return n.id == graph.selectedNode; });
|
||||
if (selIt != graph.nodes.end()) {
|
||||
Node& sel = *selIt;
|
||||
ImGui::SetNextWindowPos(canvasPos + ImVec2(10, 10));
|
||||
ImGui::SetNextWindowSize(ImVec2(250, 200));
|
||||
if (ImGui::Begin("##NodeProperties", nullptr,
|
||||
ImGuiWindowFlags_NoTitleBar | ImGuiWindowFlags_AlwaysAutoResize |
|
||||
ImGuiWindowFlags_NoMove | ImGuiWindowFlags_NoSavedSettings)) {
|
||||
ImGui::TextColored(ImVec4(1,1,0,1), "Свойства: %s", sel.title.c_str());
|
||||
ImGui::Separator();
|
||||
|
||||
if (sel.type != NodeType::Output) {
|
||||
ImGui::InputInt("Нейронов", &sel.layerSize);
|
||||
if (sel.layerSize < 1) sel.layerSize = 1;
|
||||
}
|
||||
|
||||
if (sel.type == NodeType::Input) {
|
||||
ImGui::Text("Ветка:");
|
||||
if (ImGui::RadioButton("Объединенная", sel.branch == -1)) sel.branch = -1;
|
||||
if (ImGui::RadioButton("Ветка A", sel.branch == 0)) sel.branch = 0;
|
||||
if (ImGui::RadioButton("Ветка B", sel.branch == 1)) sel.branch = 1;
|
||||
}
|
||||
|
||||
ImGui::Checkbox("Разделить выход", &sel.isSplit);
|
||||
|
||||
if (ImGui::Button("🗑 Удалить узел", ImVec2(-1, 30))) {
|
||||
// Удаляем узел и все его соединения
|
||||
graph.connections.erase(
|
||||
std::remove_if(graph.connections.begin(), graph.connections.end(),
|
||||
[id = sel.id](const Connection& c) {
|
||||
return c.fromNode == id || c.toNode == id;
|
||||
}),
|
||||
graph.connections.end());
|
||||
graph.nodes.erase(selIt);
|
||||
graph.selectedNode = -1;
|
||||
}
|
||||
ImGui::End();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === Синхронизация с LayerStructure ===
|
||||
|
||||
void SyncToLayerConfigs(GraphState& graph, std::vector<LayerStructure_t>& configs) {
|
||||
configs.clear();
|
||||
|
||||
// Сортируем узлы по позиции (приблизительный топологический порядок)
|
||||
std::vector<Node*> sortedNodes;
|
||||
for (auto& n : graph.nodes) sortedNodes.push_back(&n);
|
||||
std::sort(sortedNodes.begin(), sortedNodes.end(),
|
||||
[](Node* a, Node* b) { return a->pos.x < b->pos.x; });
|
||||
|
||||
for (auto* node : sortedNodes) {
|
||||
LayerStructure_t layer;
|
||||
layer.size = node->layerSize;
|
||||
layer.branch = node->branch;
|
||||
layer.isSplit = node->isSplit;
|
||||
|
||||
// Находим источники по соединениям
|
||||
for (const auto& conn : graph.connections) {
|
||||
if (conn.toNode == node->id) {
|
||||
// Находим индекс слоя-источника в configs
|
||||
for (int i = 0; i < (int)configs.size(); i++) {
|
||||
// Упрощенная логика - в реальности нужен маппинг node.id -> layer index
|
||||
if (i == conn.fromNode) {
|
||||
layer.sources.push_back(i);
|
||||
layer.sourceBranches.push_back(node->branch);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
configs.push_back(layer);
|
||||
}
|
||||
}
|
||||
|
||||
void SyncFromLayerConfigs(GraphState& graph, const std::vector<LayerStructure_t>& configs) {
|
||||
graph.nodes.clear();
|
||||
graph.connections.clear();
|
||||
|
||||
for (size_t i = 0; i < configs.size(); i++) {
|
||||
const auto& cfg = configs[i];
|
||||
NodeType type = cfg.sources.empty() ? NodeType::Input
|
||||
: (i == configs.size()-1 ? NodeType::Output : NodeType::Hidden);
|
||||
|
||||
Node node((int)i, type == NodeType::Input ? "Input" :
|
||||
type == NodeType::Output ? "Output" : "Hidden", type);
|
||||
node.pos = ImVec2(100 + i * 250, 100 + (i % 3) * 150);
|
||||
node.size = ImVec2(180, type == NodeType::Hidden ? 100 : 90);
|
||||
node.layerSize = cfg.size;
|
||||
node.branch = cfg.branch;
|
||||
node.isSplit = cfg.isSplit;
|
||||
node.layerIndex = (int)i;
|
||||
|
||||
if (type != NodeType::Output)
|
||||
node.outputs.push_back(Port("Out", PortType::Output));
|
||||
if (type != NodeType::Input)
|
||||
node.inputs.push_back(Port("In", PortType::Input));
|
||||
|
||||
graph.nodes.push_back(node);
|
||||
}
|
||||
|
||||
// Восстанавливаем соединения
|
||||
for (size_t i = 0; i < configs.size(); i++) {
|
||||
for (size_t j = 0; j < configs[i].sources.size(); j++) {
|
||||
int srcIdx = configs[i].sources[j];
|
||||
graph.connections.emplace_back(srcIdx, 0, (int)i, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace NodeEditor
|
||||
@@ -0,0 +1,115 @@
|
||||
#ifndef NODE_EDITOR_HPP
|
||||
#define NODE_EDITOR_HPP
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include "imgui.h"
|
||||
|
||||
namespace NodeEditor {
|
||||
|
||||
// Типы портов
|
||||
enum class PortType { Input, Output };
|
||||
enum class NodeType { Input, Hidden, Output };
|
||||
|
||||
// Порт узла
|
||||
struct Port {
|
||||
std::string name;
|
||||
PortType type;
|
||||
int index; // Индекс для множественных портов
|
||||
bool isBranchPort; // true если это порт для выбора ветки (0/1)
|
||||
|
||||
Port(const std::string& n, PortType t, int idx = 0, bool branch = false)
|
||||
: name(n), type(t), index(idx), isBranchPort(branch) {}
|
||||
};
|
||||
|
||||
// Узел графа
|
||||
struct Node {
|
||||
int id;
|
||||
std::string title;
|
||||
NodeType type;
|
||||
ImVec2 pos;
|
||||
ImVec2 size;
|
||||
bool selected;
|
||||
bool dragging;
|
||||
ImVec2 dragOffset;
|
||||
|
||||
// Данные слоя
|
||||
int layerSize;
|
||||
int layerIndex; // Индекс в ui.layerConfigs
|
||||
std::vector<int> connectedInputs; // IDs узлов, подключенных к входам
|
||||
std::vector<int> connectedOutputs; // IDs узлов, подключенных к выходам
|
||||
|
||||
// Порты
|
||||
std::vector<Port> inputs;
|
||||
std::vector<Port> outputs;
|
||||
|
||||
// Для ветвления
|
||||
int branch; // -1, 0, 1
|
||||
bool isSplit; // Разделяет ли выход на две ветки
|
||||
|
||||
Node(int id_, const std::string& title_, NodeType type_)
|
||||
: id(id_), title(title_), type(type_), pos(0,0), size(200,100),
|
||||
selected(false), dragging(false), layerSize(128), layerIndex(-1),
|
||||
branch(-1), isSplit(false) {}
|
||||
|
||||
ImVec2 GetInputPos(int portIdx) const;
|
||||
ImVec2 GetOutputPos(int portIdx) const;
|
||||
};
|
||||
|
||||
// Соединение между портами
|
||||
struct Connection {
|
||||
int fromNode;
|
||||
int fromPort;
|
||||
int toNode;
|
||||
int toPort;
|
||||
|
||||
Connection(int fn, int fp, int tn, int tp)
|
||||
: fromNode(fn), fromPort(fp), toNode(tn), toPort(tp) {}
|
||||
};
|
||||
|
||||
// Состояние редактора
|
||||
struct GraphState {
|
||||
std::vector<Node> nodes;
|
||||
std::vector<Connection> connections;
|
||||
|
||||
int nextNodeId = 0;
|
||||
int selectedNode = -1;
|
||||
int hoveredPortNode = -1;
|
||||
int hoveredPortIdx = -1;
|
||||
PortType hoveredPortType = PortType::Input;
|
||||
|
||||
// Для создания соединения
|
||||
bool creatingConnection = false;
|
||||
int connectionStartNode = -1;
|
||||
int connectionStartPort = -1;
|
||||
PortType connectionStartType = PortType::Output;
|
||||
ImVec2 connectionMousePos;
|
||||
|
||||
// Масштаб и панорамирование
|
||||
float zoom = 1.0f;
|
||||
ImVec2 panOffset;
|
||||
bool panning;
|
||||
ImVec2 panStart;
|
||||
|
||||
GraphState() : panning(false) {}
|
||||
};
|
||||
|
||||
// === API ===
|
||||
|
||||
void Init(GraphState& graph);
|
||||
void DrawGraph(GraphState& graph, const ImVec2& canvasSize);
|
||||
void HandleInput(GraphState& graph, const ImVec2& canvasPos);
|
||||
|
||||
// Синхронизация с LayerStructure_t
|
||||
void SyncToLayerConfigs(GraphState& graph, std::vector<LayerStructure_t>& configs);
|
||||
void SyncFromLayerConfigs(GraphState& graph, const std::vector<LayerStructure_t>& configs);
|
||||
|
||||
// Вспомогательные функции
|
||||
ImVec2 GetPortPos(const Node& node, const Port& port, const ImVec2& canvasOffset);
|
||||
void DrawBezier(ImDrawList* dl, ImVec2 start, ImVec2 end, ImU32 color, float thickness = 2.0f);
|
||||
ImU32 GetNodeColor(NodeType type, bool selected);
|
||||
|
||||
} // namespace NodeEditor
|
||||
|
||||
#endif // NODE_EDITOR_HPP
|
||||
+8
-7
@@ -10,12 +10,12 @@ layout(std430, binding = 4) buffer Targets { float T[]; };
|
||||
|
||||
layout(push_constant) uniform Params {
|
||||
uint mode; // 0: FF, 1: OutError, 2: BackProp, 3: Update
|
||||
uint prevSize;
|
||||
uint nextSize;
|
||||
uint prevSize; // Суммарный размер всех входных слоев
|
||||
uint nextSize; // Размер текущего слоя
|
||||
uint wOff;
|
||||
uint bOff;
|
||||
uint oOff;
|
||||
uint nextOOff;
|
||||
uint oOff; // Смещение первого входного слоя
|
||||
uint nextOOff; // Смещение текущего слоя
|
||||
float lr;
|
||||
} p;
|
||||
|
||||
@@ -35,24 +35,25 @@ void main() {
|
||||
O[p.nextOOff + idx] = sigmoid(sum);
|
||||
}
|
||||
}
|
||||
// MODE 1: Ошибка выходного слоя
|
||||
// MODE 1: Ошибка выходного слоя (MSE derivative)
|
||||
else if (p.mode == 1) {
|
||||
if (idx < p.nextSize) {
|
||||
float outVal = O[p.nextOOff + idx];
|
||||
E[p.nextOOff + idx] = (T[idx] - outVal) * dSigmoid(outVal);
|
||||
}
|
||||
}
|
||||
// MODE 2: Обратное распространение ошибки (Hidden layers)
|
||||
// MODE 2: Обратное распространение ошибки (Error propagation)
|
||||
else if (p.mode == 2) {
|
||||
if (idx < p.prevSize) {
|
||||
float errSum = 0.0;
|
||||
for (uint i = 0; i < p.nextSize; i++) {
|
||||
errSum += E[p.nextOOff + i] * W[p.wOff + i * p.prevSize + idx];
|
||||
}
|
||||
// Ошибка записывается в Errors входного слоя
|
||||
E[p.oOff + idx] = errSum * dSigmoid(O[p.oOff + idx]);
|
||||
}
|
||||
}
|
||||
// MODE 3: Обновление весов и смещений
|
||||
// MODE 3: Обновление весов и смещений (Gradient Descent)
|
||||
else if (p.mode == 3) {
|
||||
if (idx < p.nextSize) {
|
||||
float errTerm = E[p.nextOOff + idx] * p.lr;
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
#ifndef TYPEDEF_H
|
||||
#define TYPEDEF_H
|
||||
|
||||
const int MAX_CONTEXT = 256; // Сколько токенов видит сеть
|
||||
const int EMBED_DIM = 8; // Размер вектора одного токена
|
||||
|
||||
const int MIDDLE_LAYER = 128;
|
||||
|
||||
const int MAX_VOCAB = 270; // Размер словаря
|
||||
|
||||
typedef enum { SIGMOID } FunctionActivate_t;
|
||||
typedef struct { int size; FunctionActivate_t activate; } LayerStructure_t;
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user