Addestramento per riconoscere numeri scritti a mano

This commit is contained in:
Riccardo Forese 2026-02-03 11:33:19 +01:00
parent 1e648fe436
commit e3f8ee037a
9 changed files with 165 additions and 62 deletions

BIN
data/t10k-images-idx3-ubyte Normal file

Binary file not shown.

BIN
data/t10k-labels-idx1-ubyte Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

38
get_data.py Normal file
View file

@ -0,0 +1,38 @@
import urllib.request
import gzip
import os
import shutil
# Names of the four gzip-compressed MNIST archives to fetch.
files = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
]
# Each archive is requested from base_url + <file name>.
base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
# Directory (relative to the current working directory) that receives the extracted files.
data_dir = "data"

# exist_ok avoids the check-then-create race of a separate os.path.exists() test.
os.makedirs(data_dir, exist_ok=True)

for file in files:
    file_path = os.path.join(data_dir, file)
    # Same path without the ".gz" marker: where the extracted data ends up.
    unzipped_path = file_path.replace(".gz", "")

    if os.path.exists(unzipped_path):
        print(f"{unzipped_path} esiste già.")
        continue

    print(f"Scaricando {file}...")
    urllib.request.urlretrieve(base_url + file, file_path)

    print(f"Estraendo {file}...")
    # Extract into a side file and only rename it into place once it is complete.
    # Otherwise an interrupted run would leave a truncated file at unzipped_path,
    # and the existence check above would accept it as valid on the next run.
    partial_path = unzipped_path + ".part"
    try:
        with gzip.open(file_path, 'rb') as f_in, open(partial_path, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
        os.replace(partial_path, unzipped_path)
    finally:
        # Drop the archive and any half-written output, whether or not extraction succeeded.
        for leftover in (partial_path, file_path):
            if os.path.exists(leftover):
                os.remove(leftover)
    print("Fatto.")

print("\nTutti i dati sono nella cartella 'data/'!")

File diff suppressed because one or more lines are too long

View file

@ -1,67 +1,71 @@
const std = @import("std"); const std = @import("std");
const Network = @import("modular_network.zig").Network; const Network = @import("modular_network.zig").Network;
const MnistData = @import("mnist.zig").MnistData;
pub fn main() !void { pub fn main() !void {
var gpa = std.heap.GeneralPurposeAllocator(.{}){}; var gpa = std.heap.GeneralPurposeAllocator(.{}){};
const allocator = gpa.allocator(); const allocator = gpa.allocator();
defer _ = gpa.deinit(); defer _ = gpa.deinit();
std.debug.print("--- CARICAMENTO MNIST ---\n", .{});
// Carichiamo solo 1000 immagini per iniziare (per vedere se funziona veloce)
// I file sono nella cartella "data/"
var dataset = try MnistData.init(allocator, "data/train-images-idx3-ubyte", "data/train-labels-idx1-ubyte", 2000);
defer dataset.deinit();
std.debug.print("Caricate {d} immagini.\n", .{dataset.images.len});
// --- ARCHITETTURA RETE ---
var net = Network.init(allocator); var net = Network.init(allocator);
defer net.deinit(); defer net.deinit();
// --- ARCHITETTURA --- // 784 Input -> 64 Hidden -> 32 Hidden -> 10 Output
// Input(2) -> Hidden(8) -> Hidden(8) -> Hidden(4) -> Output(1) try net.addLayer(784, 64, 111);
try net.addLayer(2, 8, 123); try net.addLayer(64, 32, 222);
try net.addLayer(8, 8, 456); try net.addLayer(32, 10, 333);
try net.addLayer(8, 4, 789);
try net.addLayer(4, 1, 101);
net.printTopology(); net.printTopology();
// Dati XOR std.debug.print("--- INIZIO TRAINING MNIST ---\n", .{});
const inputs = [_][]const f32{ &.{ 0.0, 0.0 }, &.{ 0.0, 1.0 }, &.{ 1.0, 0.0 }, &.{ 1.0, 1.0 } };
const targets = [_][]const f32{ &.{0.0}, &.{1.0}, &.{1.0}, &.{0.0} };
std.debug.print("--- TRAINING DEEP CON VISUALIZER E DEBUG --- \n", .{}); const lr: f32 = 0.1;
const epochs = 50; // Meno epoche, ma ogni epoca elabora 2000 immagini!
// --- CONFIGURAZIONE ---
const lr: f32 = 0.2;
const max_epochs = 50000;
const slow_mode = true; // Attiva il rallentatore
const export_step = 100; // Ogni quante epoche aggiorniamo
const delay_ms = 25; // Ritardo in millisecondi
var epoch: usize = 0; var epoch: usize = 0;
while (epoch <= max_epochs) : (epoch += 1) { while (epoch < epochs) : (epoch += 1) {
var total_loss: f32 = 0.0; var total_loss: f32 = 0.0;
var correct: usize = 0;
// Training step for (dataset.images, 0..) |img, i| {
for (0..4) |i| { // Training
total_loss += try net.train(inputs[i], targets[i], lr); total_loss += try net.train(img, dataset.labels[i], lr);
// Calcolo precisione (Accuracy) al volo
const out = net.forward(img);
if (argmax(out) == argmax(dataset.labels[i])) {
correct += 1;
}
} }
// --- ZONA OUTPUT E EXPORT --- const accuracy = @as(f32, @floatFromInt(correct)) / @as(f32, @floatFromInt(dataset.images.len)) * 100.0;
if (epoch % export_step == 0) {
// 1. Stampiamo HEADER con Epoca e Loss std.debug.print("Epoca {d}: Loss {d:.4} | Accuracy: {d:.2}%\n", .{ epoch, total_loss / @as(f32, @floatFromInt(dataset.images.len)), accuracy });
std.debug.print("\n=== EPOCA {d} | Loss: {d:.6} ===\n", .{ epoch, total_loss });
// 2. Stampiamo le PREVISIONI attuali per i 4 casi // Salviamo lo stato per il visualizer (vedrai un "cervello" molto complesso!)
for (inputs) |inp| {
const out = net.forward(inp);
// Stampa formattata: Input -> Output
std.debug.print("In: [{d:.0}, {d:.0}] -> Out: {d:.4}\n", .{ inp[0], inp[1], out[0] });
}
// 3. Esportiamo il JSON per il browser
try net.exportJSON("network_state.json", epoch, total_loss); try net.exportJSON("network_state.json", epoch, total_loss);
}
}
// 4. Delay per l'animazione // Funzione helper per trovare l'indice del valore più alto (es: quale numero è?)
if (slow_mode) { fn argmax(slice: []const f32) usize {
// Ricorda: std.Thread.sleep vuole nanosecondi var max_val: f32 = -1000.0;
std.Thread.sleep(delay_ms * 1_000_000); var max_idx: usize = 0;
} for (slice, 0..) |val, i| {
if (val > max_val) {
max_val = val;
max_idx = i;
} }
} }
return max_idx;
} }

68
src/mnist.zig Normal file
View file

@ -0,0 +1,68 @@
const std = @import("std");
const Allocator = std.mem.Allocator;
/// A subset of an MNIST-style dataset held in memory.
///
/// `images[k]` and `labels[k]` describe the same sample. Every slice is allocated
/// with `allocator` and released by `deinit`.
pub const MnistData = struct {
    /// One slice of 784 (28 * 28) values per sample; each pixel byte is divided by 255.
    images: [][]f32,
    /// One slice of 10 values per sample: 1.0 at the label's index, 0.0 elsewhere.
    labels: [][]f32,
    /// Allocator used for every slice above.
    allocator: Allocator,

    const image_size = 28 * 28;
    const class_count = 10;
    /// Bytes skipped at the start of the image file before pixel data is read.
    const image_header_len = 16;
    /// Bytes skipped at the start of the label file before label bytes are read.
    const label_header_len = 8;

    /// Loads up to `max_items` samples from the files at `img_path` and `lbl_path`
    /// (both resolved relative to the current working directory).
    ///
    /// Loading stops early, without error, when either file runs out of complete
    /// records. Returns `error.InvalidLabel` if a label byte is not in 0..9, plus any
    /// file-system or allocation error. On error nothing allocated here is leaked.
    /// On success the caller owns the result and must call `deinit`.
    pub fn init(allocator: Allocator, img_path: []const u8, lbl_path: []const u8, max_items: usize) !MnistData {
        const img_file = try std.fs.cwd().openFile(img_path, .{});
        defer img_file.close();
        try img_file.seekTo(image_header_len);

        const lbl_file = try std.fs.cwd().openFile(lbl_path, .{});
        defer lbl_file.close();
        try lbl_file.seekTo(label_header_len);

        // On any error below, free the samples collected so far as well as the lists.
        var images: std.ArrayList([]f32) = .{};
        errdefer {
            for (images.items) |img| allocator.free(img);
            images.deinit(allocator);
        }
        var labels: std.ArrayList([]f32) = .{};
        errdefer {
            for (labels.items) |lbl| allocator.free(lbl);
            labels.deinit(allocator);
        }

        var pixel_bytes: [image_size]u8 = undefined;
        var label_byte: [1]u8 = undefined;

        // The sample count is the only loop counter, so exactly `max_items` samples
        // are loaded when the files contain that many.
        while (images.items.len < max_items) {
            if (try readFull(img_file, &pixel_bytes) < image_size) break;
            if (try readFull(lbl_file, &label_byte) < 1) break;

            // The label is used as an index into a `class_count`-element slice.
            const digit = label_byte[0];
            if (digit >= class_count) return error.InvalidLabel;

            // Reserve list space first so that each freshly allocated slice can be
            // handed to its list immediately and is covered by the errdefers above.
            try images.ensureUnusedCapacity(allocator, 1);
            try labels.ensureUnusedCapacity(allocator, 1);

            const img_float = try allocator.alloc(f32, image_size);
            images.appendAssumeCapacity(img_float);
            for (pixel_bytes, 0..) |b, px| {
                img_float[px] = @as(f32, @floatFromInt(b)) / 255.0;
            }

            const lbl_float = try allocator.alloc(f32, class_count);
            labels.appendAssumeCapacity(lbl_float);
            @memset(lbl_float, 0.0);
            lbl_float[digit] = 1.0;
        }

        const owned_images = try images.toOwnedSlice(allocator);
        // `images` is empty from here on, so its errdefer no longer covers these slices.
        errdefer {
            for (owned_images) |img| allocator.free(img);
            allocator.free(owned_images);
        }
        const owned_labels = try labels.toOwnedSlice(allocator);

        return MnistData{
            .images = owned_images,
            .labels = owned_labels,
            .allocator = allocator,
        };
    }

    /// Frees every image and label slice, then the two outer slices.
    pub fn deinit(self: *MnistData) void {
        for (self.images) |img| self.allocator.free(img);
        for (self.labels) |lbl| self.allocator.free(lbl);
        self.allocator.free(self.images);
        self.allocator.free(self.labels);
    }

    /// Reads until `dest` is full or the file reports end-of-file; returns the number
    /// of bytes stored. A single `read` may return fewer bytes than requested.
    fn readFull(file: std.fs.File, dest: []u8) !usize {
        var filled: usize = 0;
        while (filled < dest.len) {
            const n = try file.read(dest[filled..]);
            if (n == 0) break;
            filled += n;
        }
        return filled;
    }
};

View file

@ -122,8 +122,7 @@
} }
} }
// Aggiorna ogni 200ms setInterval(update, 20);
setInterval(update, 50);
</script> </script>
</body> </body>
</html> </html>