Addestramento per riconoscere numeri scritti a mano
This commit is contained in:
parent
1e648fe436
commit
e3f8ee037a
BIN
data/t10k-images-idx3-ubyte
Normal file
BIN
data/t10k-images-idx3-ubyte
Normal file
Binary file not shown.
BIN
data/t10k-labels-idx1-ubyte
Normal file
BIN
data/t10k-labels-idx1-ubyte
Normal file
Binary file not shown.
BIN
data/train-images-idx3-ubyte
Normal file
BIN
data/train-images-idx3-ubyte
Normal file
Binary file not shown.
BIN
data/train-labels-idx1-ubyte
Normal file
BIN
data/train-labels-idx1-ubyte
Normal file
Binary file not shown.
38
get_data.py
Normal file
38
get_data.py
Normal file
|
|
@ -0,0 +1,38 @@
|
||||||
|
import urllib.request
|
||||||
|
import gzip
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
files = [
|
||||||
|
"train-images-idx3-ubyte.gz",
|
||||||
|
"train-labels-idx1-ubyte.gz",
|
||||||
|
"t10k-images-idx3-ubyte.gz",
|
||||||
|
"t10k-labels-idx1-ubyte.gz"
|
||||||
|
]
|
||||||
|
|
||||||
|
base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
|
||||||
|
data_dir = "data"
|
||||||
|
|
||||||
|
if not os.path.exists(data_dir):
|
||||||
|
os.makedirs(data_dir)
|
||||||
|
|
||||||
|
for file in files:
|
||||||
|
file_path = os.path.join(data_dir, file)
|
||||||
|
unzipped_path = file_path.replace(".gz", "")
|
||||||
|
|
||||||
|
if not os.path.exists(unzipped_path):
|
||||||
|
print(f"Scaricando {file}...")
|
||||||
|
urllib.request.urlretrieve(base_url + file, file_path)
|
||||||
|
|
||||||
|
print(f"Estraendo {file}...")
|
||||||
|
with gzip.open(file_path, 'rb') as f_in:
|
||||||
|
with open(unzipped_path, 'wb') as f_out:
|
||||||
|
shutil.copyfileobj(f_in, f_out)
|
||||||
|
|
||||||
|
# Rimuoviamo il file .gz per pulizia
|
||||||
|
os.remove(file_path)
|
||||||
|
print("Fatto.")
|
||||||
|
else:
|
||||||
|
print(f"{unzipped_path} esiste già.")
|
||||||
|
|
||||||
|
print("\nTutti i dati sono nella cartella 'data/'!")
|
||||||
File diff suppressed because one or more lines are too long
82
src/main.zig
82
src/main.zig
|
|
@ -1,67 +1,71 @@
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const Network = @import("modular_network.zig").Network;
|
const Network = @import("modular_network.zig").Network;
|
||||||
|
const MnistData = @import("mnist.zig").MnistData;
|
||||||
|
|
||||||
pub fn main() !void {
|
pub fn main() !void {
|
||||||
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
|
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
|
||||||
const allocator = gpa.allocator();
|
const allocator = gpa.allocator();
|
||||||
defer _ = gpa.deinit();
|
defer _ = gpa.deinit();
|
||||||
|
|
||||||
|
std.debug.print("--- CARICAMENTO MNIST ---\n", .{});
|
||||||
|
|
||||||
|
// Carichiamo solo 1000 immagini per iniziare (per vedere se funziona veloce)
|
||||||
|
// I file sono nella cartella "data/"
|
||||||
|
var dataset = try MnistData.init(allocator, "data/train-images-idx3-ubyte", "data/train-labels-idx1-ubyte", 2000);
|
||||||
|
defer dataset.deinit();
|
||||||
|
|
||||||
|
std.debug.print("Caricate {d} immagini.\n", .{dataset.images.len});
|
||||||
|
|
||||||
|
// --- ARCHITETTURA RETE ---
|
||||||
var net = Network.init(allocator);
|
var net = Network.init(allocator);
|
||||||
defer net.deinit();
|
defer net.deinit();
|
||||||
|
|
||||||
// --- ARCHITETTURA ---
|
// 784 Input -> 64 Hidden -> 32 Hidden -> 10 Output
|
||||||
// Input(2) -> Hidden(8) -> Hidden(8) -> Hidden(4) -> Output(1)
|
try net.addLayer(784, 64, 111);
|
||||||
try net.addLayer(2, 8, 123);
|
try net.addLayer(64, 32, 222);
|
||||||
try net.addLayer(8, 8, 456);
|
try net.addLayer(32, 10, 333);
|
||||||
try net.addLayer(8, 4, 789);
|
|
||||||
try net.addLayer(4, 1, 101);
|
|
||||||
|
|
||||||
net.printTopology();
|
net.printTopology();
|
||||||
|
|
||||||
// Dati XOR
|
std.debug.print("--- INIZIO TRAINING MNIST ---\n", .{});
|
||||||
const inputs = [_][]const f32{ &.{ 0.0, 0.0 }, &.{ 0.0, 1.0 }, &.{ 1.0, 0.0 }, &.{ 1.0, 1.0 } };
|
|
||||||
const targets = [_][]const f32{ &.{0.0}, &.{1.0}, &.{1.0}, &.{0.0} };
|
|
||||||
|
|
||||||
std.debug.print("--- TRAINING DEEP CON VISUALIZER E DEBUG --- \n", .{});
|
const lr: f32 = 0.1;
|
||||||
|
const epochs = 50; // Meno epoche, ma ogni epoca elabora 2000 immagini!
|
||||||
// --- CONFIGURAZIONE ---
|
|
||||||
const lr: f32 = 0.2;
|
|
||||||
const max_epochs = 50000;
|
|
||||||
|
|
||||||
const slow_mode = true; // Attiva il rallentatore
|
|
||||||
const export_step = 100; // Ogni quante epoche aggiorniamo
|
|
||||||
const delay_ms = 25; // Ritardo in millisecondi
|
|
||||||
|
|
||||||
var epoch: usize = 0;
|
var epoch: usize = 0;
|
||||||
while (epoch <= max_epochs) : (epoch += 1) {
|
while (epoch < epochs) : (epoch += 1) {
|
||||||
var total_loss: f32 = 0.0;
|
var total_loss: f32 = 0.0;
|
||||||
|
var correct: usize = 0;
|
||||||
|
|
||||||
// Training step
|
for (dataset.images, 0..) |img, i| {
|
||||||
for (0..4) |i| {
|
// Training
|
||||||
total_loss += try net.train(inputs[i], targets[i], lr);
|
total_loss += try net.train(img, dataset.labels[i], lr);
|
||||||
|
|
||||||
|
// Calcolo precisione (Accuracy) al volo
|
||||||
|
const out = net.forward(img);
|
||||||
|
if (argmax(out) == argmax(dataset.labels[i])) {
|
||||||
|
correct += 1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- ZONA OUTPUT E EXPORT ---
|
const accuracy = @as(f32, @floatFromInt(correct)) / @as(f32, @floatFromInt(dataset.images.len)) * 100.0;
|
||||||
if (epoch % export_step == 0) {
|
|
||||||
|
|
||||||
// 1. Stampiamo HEADER con Epoca e Loss
|
std.debug.print("Epoca {d}: Loss {d:.4} | Accuracy: {d:.2}%\n", .{ epoch, total_loss / @as(f32, @floatFromInt(dataset.images.len)), accuracy });
|
||||||
std.debug.print("\n=== EPOCA {d} | Loss: {d:.6} ===\n", .{ epoch, total_loss });
|
|
||||||
|
|
||||||
// 2. Stampiamo le PREVISIONI attuali per i 4 casi
|
// Salviamo lo stato per il visualizer (vedrai un "cervello" molto complesso!)
|
||||||
for (inputs) |inp| {
|
|
||||||
const out = net.forward(inp);
|
|
||||||
// Stampa formattata: Input -> Output
|
|
||||||
std.debug.print("In: [{d:.0}, {d:.0}] -> Out: {d:.4}\n", .{ inp[0], inp[1], out[0] });
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. Esportiamo il JSON per il browser
|
|
||||||
try net.exportJSON("network_state.json", epoch, total_loss);
|
try net.exportJSON("network_state.json", epoch, total_loss);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 4. Delay per l'animazione
|
// Funzione helper per trovare l'indice del valore più alto (es: quale numero è?)
|
||||||
if (slow_mode) {
|
fn argmax(slice: []const f32) usize {
|
||||||
// Ricorda: std.Thread.sleep vuole nanosecondi
|
var max_val: f32 = -1000.0;
|
||||||
std.Thread.sleep(delay_ms * 1_000_000);
|
var max_idx: usize = 0;
|
||||||
}
|
for (slice, 0..) |val, i| {
|
||||||
|
if (val > max_val) {
|
||||||
|
max_val = val;
|
||||||
|
max_idx = i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return max_idx;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
68
src/mnist.zig
Normal file
68
src/mnist.zig
Normal file
|
|
@ -0,0 +1,68 @@
|
||||||
|
const std = @import("std");
|
||||||
|
const Allocator = std.mem.Allocator;
|
||||||
|
|
||||||
|
pub const MnistData = struct {
|
||||||
|
images: [][]f32,
|
||||||
|
labels: [][]f32,
|
||||||
|
allocator: Allocator,
|
||||||
|
|
||||||
|
pub fn init(allocator: Allocator, img_path: []const u8, lbl_path: []const u8, max_items: usize) !MnistData {
|
||||||
|
// --- 1. CARICAMENTO IMMAGINI ---
|
||||||
|
const img_file = try std.fs.cwd().openFile(img_path, .{});
|
||||||
|
defer img_file.close();
|
||||||
|
try img_file.seekTo(16);
|
||||||
|
|
||||||
|
// --- 2. CARICAMENTO ETICHETTE ---
|
||||||
|
const lbl_file = try std.fs.cwd().openFile(lbl_path, .{});
|
||||||
|
defer lbl_file.close();
|
||||||
|
try lbl_file.seekTo(8);
|
||||||
|
|
||||||
|
// --- 3. ALLOCAZIONE ---
|
||||||
|
var images: std.ArrayList([]f32) = .{};
|
||||||
|
defer images.deinit(allocator);
|
||||||
|
var labels: std.ArrayList([]f32) = .{};
|
||||||
|
defer labels.deinit(allocator);
|
||||||
|
|
||||||
|
const img_size = 28 * 28;
|
||||||
|
var buffer: [784]u8 = undefined;
|
||||||
|
var label_byte: [1]u8 = undefined;
|
||||||
|
var i: usize = 0;
|
||||||
|
|
||||||
|
while (i < max_items) : (i += 1) {
|
||||||
|
// Leggi direttamente dal file
|
||||||
|
const bytes_read = try img_file.read(&buffer);
|
||||||
|
if (bytes_read < img_size) break;
|
||||||
|
|
||||||
|
// Leggi 1 byte per l'etichetta
|
||||||
|
const lbl_read = try lbl_file.read(&label_byte);
|
||||||
|
if (lbl_read < 1) break;
|
||||||
|
|
||||||
|
const img_float = try allocator.alloc(f32, img_size);
|
||||||
|
for (buffer, 0..) |b, px| {
|
||||||
|
img_float[px] = @as(f32, @floatFromInt(b)) / 255.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
const lbl_float = try allocator.alloc(f32, 10);
|
||||||
|
@memset(lbl_float, 0.0);
|
||||||
|
lbl_float[label_byte[0]] = 1.0;
|
||||||
|
|
||||||
|
try images.append(allocator, img_float);
|
||||||
|
try labels.append(allocator, lbl_float);
|
||||||
|
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return MnistData{
|
||||||
|
.images = try images.toOwnedSlice(allocator),
|
||||||
|
.labels = try labels.toOwnedSlice(allocator),
|
||||||
|
.allocator = allocator,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn deinit(self: *MnistData) void {
|
||||||
|
for (self.images) |img| self.allocator.free(img);
|
||||||
|
for (self.labels) |lbl| self.allocator.free(lbl);
|
||||||
|
self.allocator.free(self.images);
|
||||||
|
self.allocator.free(self.labels);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
@ -122,8 +122,7 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Aggiorna ogni 200ms
|
setInterval(update, 20);
|
||||||
setInterval(update, 50);
|
|
||||||
</script>
|
</script>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
Loading…
Reference in a new issue