Addestramento per riconoscere numeri scritti a mano
This commit is contained in:
parent
1e648fe436
commit
e3f8ee037a
BIN
data/t10k-images-idx3-ubyte
Normal file
BIN
data/t10k-images-idx3-ubyte
Normal file
Binary file not shown.
BIN
data/t10k-labels-idx1-ubyte
Normal file
BIN
data/t10k-labels-idx1-ubyte
Normal file
Binary file not shown.
BIN
data/train-images-idx3-ubyte
Normal file
BIN
data/train-images-idx3-ubyte
Normal file
Binary file not shown.
BIN
data/train-labels-idx1-ubyte
Normal file
BIN
data/train-labels-idx1-ubyte
Normal file
Binary file not shown.
38
get_data.py
Normal file
38
get_data.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
# Downloads the four MNIST archive files and extracts them into `data/`.
#
# A file is skipped when its extracted form already exists. Extraction is
# written to a temporary ".part" file and renamed only after it completes,
# so an interrupted run never leaves a truncated dataset file that a later
# run would mistake for a finished one.
import urllib.request
import gzip
import os
import shutil

files = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz"
]

base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
data_dir = "data"

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(data_dir, exist_ok=True)

for file in files:
    file_path = os.path.join(data_dir, file)
    # Strip only the trailing ".gz" suffix (str.replace would hit every occurrence).
    unzipped_path = file_path[:-len(".gz")] if file_path.endswith(".gz") else file_path

    if os.path.exists(unzipped_path):
        print(f"{unzipped_path} esiste già.")
        continue

    partial_path = unzipped_path + ".part"
    try:
        print(f"Scaricando {file}...")
        urllib.request.urlretrieve(base_url + file, file_path)

        print(f"Estraendo {file}...")
        with gzip.open(file_path, 'rb') as f_in:
            with open(partial_path, 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)

        # Publish the finished file atomically.
        os.replace(partial_path, unzipped_path)
    finally:
        # Remove the .gz archive (and any half-written output) whether or not
        # the download/extraction succeeded.
        for leftover in (file_path, partial_path):
            if os.path.exists(leftover):
                os.remove(leftover)
    print("Fatto.")

print("\nTutti i dati sono nella cartella 'data/'!")
|
||||
File diff suppressed because one or more lines are too long
86
src/main.zig
86
src/main.zig
|
|
@ -1,67 +1,71 @@
|
|||
const std = @import("std");
|
||||
const Network = @import("modular_network.zig").Network;
|
||||
const MnistData = @import("mnist.zig").MnistData;
|
||||
|
||||
/// Entry point: loads a subset of the MNIST training set, builds a
/// 784 -> 64 -> 32 -> 10 network, and trains it for a fixed number of epochs,
/// printing the mean loss and training accuracy after every epoch and
/// exporting the network state to "network_state.json".
///
/// NOTE(review): this body is the post-commit side of a diff view in which
/// removed and added lines were merged without markers; confirm against the
/// repository before applying.
pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    const allocator = gpa.allocator();
    defer _ = gpa.deinit();

    std.debug.print("--- CARICAMENTO MNIST ---\n", .{});

    // Load at most 2000 images to start with, so a first run finishes quickly.
    // The dataset files live in the "data/" directory.
    var dataset = try MnistData.init(allocator, "data/train-images-idx3-ubyte", "data/train-labels-idx1-ubyte", 2000);
    defer dataset.deinit();

    std.debug.print("Caricate {d} immagini.\n", .{dataset.images.len});

    // --- NETWORK ARCHITECTURE ---
    var net = Network.init(allocator);
    defer net.deinit();

    // 784 inputs -> 64 hidden -> 32 hidden -> 10 outputs.
    // The third argument is presumably a per-layer RNG seed — TODO confirm in modular_network.zig.
    try net.addLayer(784, 64, 111);
    try net.addLayer(64, 32, 222);
    try net.addLayer(32, 10, 333);

    net.printTopology();

    std.debug.print("--- INIZIO TRAINING MNIST ---\n", .{});

    const lr: f32 = 0.1;
    const epochs = 50; // Fewer epochs, but each one processes every loaded image.

    var epoch: usize = 0;
    while (epoch < epochs) : (epoch += 1) {
        var total_loss: f32 = 0.0;
        var correct: usize = 0;

        for (dataset.images, 0..) |img, i| {
            // One training step on this sample.
            total_loss += try net.train(img, dataset.labels[i], lr);

            // Running accuracy: compare the predicted class with the one-hot label.
            const out = net.forward(img);
            if (argmax(out) == argmax(dataset.labels[i])) {
                correct += 1;
            }
        }

        const accuracy = @as(f32, @floatFromInt(correct)) / @as(f32, @floatFromInt(dataset.images.len)) * 100.0;

        std.debug.print("Epoca {d}: Loss {d:.4} | Accuracy: {d:.2}%\n", .{ epoch, total_loss / @as(f32, @floatFromInt(dataset.images.len)), accuracy });

        // Save the current state for the browser visualizer.
        try net.exportJSON("network_state.json", epoch, total_loss);
    }
}
|
||||
|
||||
/// Returns the index of the largest value in `slice` (e.g. which digit a
/// 10-element output vector predicts).
///
/// Ties resolve to the first occurrence. An empty slice yields 0. NaN entries
/// never compare greater and are therefore skipped.
fn argmax(slice: []const f32) usize {
    // Start from negative infinity rather than a finite sentinel, so inputs
    // whose values are all very negative still report the true maximum.
    var max_val: f32 = -std.math.inf(f32);
    var max_idx: usize = 0;
    for (slice, 0..) |val, i| {
        if (val > max_val) {
            max_val = val;
            max_idx = i;
        }
    }
    return max_idx;
}
|
||||
|
|
|
|||
68
src/mnist.zig
Normal file
68
src/mnist.zig
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
const std = @import("std");
|
||||
const Allocator = std.mem.Allocator;
|
||||
|
||||
/// In-memory subset of an MNIST-style dataset.
///
/// `images[k]` holds 784 pixel values scaled to the range 0..1 and
/// `labels[k]` holds 10 values with a single 1.0 at the class index.
/// `images` and `labels` always have the same length. Every slice is owned
/// by this struct and released by `deinit`.
pub const MnistData = struct {
    images: [][]f32,
    labels: [][]f32,
    allocator: Allocator,

    /// Reads up to `max_items` image/label pairs from the two files.
    ///
    /// The first 16 bytes of the image file and the first 8 bytes of the
    /// label file are skipped as headers. Loading stops early, without error,
    /// when either file runs out of complete records.
    ///
    /// Returns `error.InvalidLabel` when a label byte is not in 0..9, plus any
    /// file or allocation error. Nothing is leaked on failure.
    /// The caller must call `deinit` on the returned value.
    pub fn init(allocator: Allocator, img_path: []const u8, lbl_path: []const u8, max_items: usize) !MnistData {
        const img_size = 28 * 28;
        const num_classes = 10;

        // --- 1. Open the image file and skip its header ---
        const img_file = try std.fs.cwd().openFile(img_path, .{});
        defer img_file.close();
        try img_file.seekTo(16);

        // --- 2. Open the label file and skip its header ---
        const lbl_file = try std.fs.cwd().openFile(lbl_path, .{});
        defer lbl_file.close();
        try lbl_file.seekTo(8);

        // --- 3. Accumulate records ---
        // The lists own every buffer appended to them. After a successful
        // toOwnedSlice a list is empty, so these cleanups only free anything
        // on an error path.
        var images: std.ArrayList([]f32) = .{};
        defer {
            for (images.items) |img| allocator.free(img);
            images.deinit(allocator);
        }
        var labels: std.ArrayList([]f32) = .{};
        defer {
            for (labels.items) |lbl| allocator.free(lbl);
            labels.deinit(allocator);
        }

        var buffer: [img_size]u8 = undefined;
        var label_byte: [1]u8 = undefined;

        for (0..max_items) |_| {
            // readAll keeps reading until the buffer is full or EOF, so a
            // short count really means the file has no further full record.
            const bytes_read = try img_file.readAll(&buffer);
            if (bytes_read < img_size) break;

            const lbl_read = try lbl_file.readAll(&label_byte);
            if (lbl_read < 1) break;

            const class = label_byte[0];
            if (class >= num_classes) return error.InvalidLabel;

            // Reserve list slots first so each buffer can be handed to its
            // list immediately after allocation (no window where it is unowned).
            try images.ensureUnusedCapacity(allocator, 1);
            try labels.ensureUnusedCapacity(allocator, 1);

            const img_float = try allocator.alloc(f32, img_size);
            images.appendAssumeCapacity(img_float);
            for (buffer, 0..) |b, px| {
                img_float[px] = @as(f32, @floatFromInt(b)) / 255.0;
            }

            const lbl_float = try allocator.alloc(f32, num_classes);
            labels.appendAssumeCapacity(lbl_float);
            @memset(lbl_float, 0.0);
            lbl_float[class] = 1.0;
        }

        const owned_images = try images.toOwnedSlice(allocator);
        errdefer {
            for (owned_images) |img| allocator.free(img);
            allocator.free(owned_images);
        }
        const owned_labels = try labels.toOwnedSlice(allocator);

        return MnistData{
            .images = owned_images,
            .labels = owned_labels,
            .allocator = allocator,
        };
    }

    /// Frees every image and label buffer and the two outer slices.
    pub fn deinit(self: *MnistData) void {
        for (self.images) |img| self.allocator.free(img);
        for (self.labels) |lbl| self.allocator.free(lbl);
        self.allocator.free(self.images);
        self.allocator.free(self.labels);
    }
};
|
||||
|
|
@ -122,8 +122,7 @@
|
|||
}
|
||||
}
|
||||
|
||||
// Aggiorna ogni 20ms
|
||||
setInterval(update, 50);
|
||||
setInterval(update, 20);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
Loading…
Reference in a new issue