70 lines
2.3 KiB
Zig
70 lines
2.3 KiB
Zig
const std = @import("std");
|
|
const Neuron = @import("neuron.zig").Neuron;
|
|
const Tensor = @import("tensor.zig").Tensor;
|
|
|
|
pub fn main() !void {
|
|
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
|
|
const allocator = gpa.allocator();
|
|
defer _ = gpa.deinit();
|
|
|
|
// 1. Inizializziamo il neurone (2 input perché la porta AND ha 2 ingressi)
|
|
var my_neuron = try Neuron.init(allocator, 2);
|
|
defer my_neuron.deinit();
|
|
|
|
// 2. Prepariamo il Dataset (AND Gate)
|
|
// Creiamo un tensore riutilizzabile per gli input
|
|
var input_tensor = try Tensor.init(allocator, &[_]usize{2});
|
|
defer input_tensor.deinit();
|
|
|
|
// I 4 casi possibili (Training Data)
|
|
const training_data = [_][2]f32{
|
|
.{ 0.0, 0.0 },
|
|
.{ 0.0, 1.0 },
|
|
.{ 1.0, 0.0 },
|
|
.{ 1.0, 1.0 },
|
|
};
|
|
|
|
// Le 4 risposte corrette (Labels)
|
|
const targets = [_]f32{ 0.0, 0.0, 0.0, 1.0 };
|
|
|
|
const lr: f32 = 0.1; // Learning rate un po' più aggressivo
|
|
|
|
std.debug.print("--- INIZIO TRAINING (AND GATE) ---\n", .{});
|
|
|
|
// 3. Ciclo di Training
|
|
var epoch: usize = 0;
|
|
while (epoch < 2000) : (epoch += 1) { // 2000 Epoche
|
|
var total_error: f32 = 0.0;
|
|
|
|
// Per ogni epoca, passiamo attraverso TUTTI gli esempi
|
|
for (training_data, 0..) |data, index| {
|
|
// Carichiamo i dati nel tensore
|
|
input_tensor.data[0] = data[0];
|
|
input_tensor.data[1] = data[1];
|
|
|
|
// Train su questo specifico esempio
|
|
const loss = my_neuron.train(input_tensor, targets[index], lr);
|
|
total_error += loss;
|
|
}
|
|
|
|
// Stampiamo ogni 200 epoche
|
|
if (epoch % 200 == 0) {
|
|
std.debug.print("Epoca {d}: Errore Medio = {d:.6}\n", .{ epoch, total_error / 4.0 });
|
|
}
|
|
}
|
|
|
|
std.debug.print("\n--- TEST FINALE ---\n", .{});
|
|
|
|
// Verifichiamo cosa ha imparato
|
|
for (training_data) |data| {
|
|
input_tensor.data[0] = data[0];
|
|
input_tensor.data[1] = data[1];
|
|
const prediction = my_neuron.forward(input_tensor);
|
|
|
|
// Arrotondiamo visivamente per capire se è 0 o 1
|
|
const result_bool: u8 = if (prediction > 0.5) 1 else 0;
|
|
|
|
std.debug.print("Input: {d:.1}, {d:.1} -> Predizione: {d:.4} (Interpretato: {d})\n", .{ data[0], data[1], prediction, result_bool });
|
|
}
|
|
}
|