New matrix representation. Add col and row iterators to Zig matrix example.

This commit is contained in:
Maciej Krzyżanowski 2024-08-25 23:58:46 +02:00
parent e72623880d
commit e1883ac2e5

View File

@ -1,23 +1,70 @@
const std = @import("std"); const std = @import("std");
const MatrixOpError = error{ const MatrixOpError = error{
IncompatibleSizes, IncompatibleDims,
OutOfBounds, OutOfBounds,
}; };
const Matrix = struct { const Matrix = struct {
allocator: std.mem.Allocator, const RowIterator = struct {
numbers: []i64, matrix: Matrix,
dimensions: []const u64, rowIdx: u64,
nextColIdx: u64,
pub fn init(allocator: std.mem.Allocator, dimensions: []const u64, init_numbers: []const i64) !Matrix { pub fn init(matrix: Matrix, rowIdx: u64) RowIterator {
var size: u64 = 1; return RowIterator{
.matrix = matrix,
for (dimensions) |dim| { .rowIdx = rowIdx,
size *= dim; .nextColIdx = 0,
};
} }
var numbers = try allocator.alloc(i64, size); pub fn next(self: *RowIterator) ?i64 {
if (self.nextColIdx == self.matrix.cols) {
return null;
}
const nextValue = self.matrix.get(.{ self.rowIdx, self.nextColIdx }) catch unreachable;
self.nextColIdx += 1;
return nextValue;
}
};
const ColIterator = struct {
matrix: Matrix,
colIdx: u64,
nextRowIdx: u64,
pub fn init(matrix: Matrix, rowIdx: u64) ColIterator {
return ColIterator{
.matrix = matrix,
.colIdx = rowIdx,
.nextRowIdx = 0,
};
}
pub fn next(self: *ColIterator) ?i64 {
if (self.nextRowIdx == self.matrix.rows) {
return null;
}
const nextValue = self.matrix.get(.{ self.nextRowIdx, self.colIdx }) catch unreachable;
self.nextRowIdx += 1;
return nextValue;
}
};
allocator: std.mem.Allocator,
numbers: []i64,
rows: u64,
cols: u64,
pub fn init(allocator: std.mem.Allocator, rows: u64, cols: u64, init_numbers: []const i64) !Matrix {
var numbers = try allocator.alloc(i64, rows * cols);
for (0.., init_numbers) |i, num| { for (0.., init_numbers) |i, num| {
numbers[i] = num; numbers[i] = num;
@ -26,23 +73,16 @@ const Matrix = struct {
return Matrix{ return Matrix{
.allocator = allocator, .allocator = allocator,
.numbers = numbers, .numbers = numbers,
.dimensions = dimensions, .rows = rows,
.cols = cols,
}; };
} }
fn position_deep_to_flat(self: Matrix, position: []const u64) u64 { fn position_deep_to_flat(self: Matrix, position: [2]u64) u64 {
var flat_position: u64 = 0; return position[0] * self.cols + position[1];
var curr_volume: u64 = 1;
for (0.., position) |i, pos| {
flat_position += pos * curr_volume;
curr_volume *= self.dimensions[i];
} }
return flat_position; pub fn get(self: Matrix, position: [2]u64) !i64 {
}
pub fn get(self: Matrix, position: []const u64) !i64 {
const flat_position = position_deep_to_flat(self, position); const flat_position = position_deep_to_flat(self, position);
if (flat_position >= self.numbers.len) { if (flat_position >= self.numbers.len) {
@ -52,7 +92,7 @@ const Matrix = struct {
return self.numbers[flat_position]; return self.numbers[flat_position];
} }
pub fn set(self: Matrix, position: []const u64, value: i64) !void { pub fn set(self: Matrix, position: [2]u64, value: i64) !void {
const flat_position = position_deep_to_flat(self, position); const flat_position = position_deep_to_flat(self, position);
if (flat_position >= self.numbers.len) { if (flat_position >= self.numbers.len) {
@ -71,17 +111,18 @@ const Matrix = struct {
} }
pub fn add(self: Matrix, allocator: std.mem.Allocator, other: Matrix) !Matrix { pub fn add(self: Matrix, allocator: std.mem.Allocator, other: Matrix) !Matrix {
if (!std.mem.eql(u64, self.dimensions, other.dimensions)) { if (self.rows != other.rows or self.cols != other.cols) {
return error.IncompatibleSizes; return error.IncompatibleSizes;
} }
var result: Matrix = try Matrix.init( var result: Matrix = try Matrix.init(
allocator, allocator,
self.dimensions, self.rows,
self.cols,
self.numbers, self.numbers,
); );
for (other.numbers, 0..) |num, idx| { for (0.., other.numbers) |idx, num| {
result.numbers[idx] += num; result.numbers[idx] += num;
} }
@ -89,17 +130,18 @@ const Matrix = struct {
} }
pub fn sub(self: Matrix, allocator: std.mem.Allocator, other: Matrix) !Matrix { pub fn sub(self: Matrix, allocator: std.mem.Allocator, other: Matrix) !Matrix {
if (!std.mem.eql(u64, self.dimensions, other.dimensions)) { if (self.rows != other.rows or self.cols != other.cols) {
return error.IncompatibleSizes; return error.IncompatibleSizes;
} }
var result: Matrix = try Matrix.init( var result: Matrix = try Matrix.init(
allocator, allocator,
self.dimensions, self.rows,
self.cols,
self.numbers, self.numbers,
); );
for (other.numbers, 0..) |num, idx| { for (0.., other.numbers) |idx, num| {
result.numbers[idx] -= num; result.numbers[idx] -= num;
} }
@ -109,7 +151,8 @@ const Matrix = struct {
pub fn neg(self: Matrix, allocator: std.mem.Allocator) !Matrix { pub fn neg(self: Matrix, allocator: std.mem.Allocator) !Matrix {
var result: Matrix = try Matrix.init( var result: Matrix = try Matrix.init(
allocator, allocator,
self.dimensions, self.rows,
self.cols,
self.numbers, self.numbers,
); );
@ -120,6 +163,20 @@ const Matrix = struct {
return result; return result;
} }
pub fn row(self: Matrix, rowIdx: u64) RowIterator {
return RowIterator.init(
self,
rowIdx,
);
}
pub fn col(self: Matrix, colIdx: u64) ColIterator {
return ColIterator.init(
self,
colIdx,
);
}
pub fn deinit(self: Matrix) void { pub fn deinit(self: Matrix) void {
self.allocator.free(self.numbers); self.allocator.free(self.numbers);
} }
@ -136,51 +193,34 @@ pub fn main() void {
} }
test "matrix init" { test "matrix init" {
const mat = try Matrix.init(std.testing.allocator, &.{ 5, 5 }, &.{}); const mat = try Matrix.init(std.testing.allocator, 5, 5, &.{});
defer mat.deinit(); defer mat.deinit();
try std.testing.expect(mat.numbers.len == 25); try std.testing.expect(mat.numbers.len == 25);
} }
test "matrix get 2 dim" { test "matrix get" {
const mat = try Matrix.init( const mat = try Matrix.init(
std.testing.allocator, std.testing.allocator,
&.{ 3, 3 }, 3,
3,
&.{ 1, 2, 3, 4, 5, 6, 7, 8, 9 }, &.{ 1, 2, 3, 4, 5, 6, 7, 8, 9 },
); );
defer mat.deinit(); defer mat.deinit();
try std.testing.expect(try mat.get(&.{ 0, 0 }) == 1); try std.testing.expect(try mat.get(.{ 0, 0 }) == 1);
try std.testing.expect(try mat.get(&.{ 1, 0 }) == 2); try std.testing.expect(try mat.get(.{ 0, 1 }) == 2);
try std.testing.expect(try mat.get(&.{ 2, 0 }) == 3); try std.testing.expect(try mat.get(.{ 0, 2 }) == 3);
try std.testing.expect(try mat.get(&.{ 0, 1 }) == 4); try std.testing.expect(try mat.get(.{ 1, 0 }) == 4);
try std.testing.expect(try mat.get(&.{ 1, 1 }) == 5); try std.testing.expect(try mat.get(.{ 1, 1 }) == 5);
try std.testing.expect(try mat.get(&.{ 2, 1 }) == 6); try std.testing.expect(try mat.get(.{ 1, 2 }) == 6);
try std.testing.expect(try mat.get(&.{ 0, 2 }) == 7); try std.testing.expect(try mat.get(.{ 2, 0 }) == 7);
try std.testing.expect(try mat.get(&.{ 1, 2 }) == 8); try std.testing.expect(try mat.get(.{ 2, 1 }) == 8);
try std.testing.expect(try mat.get(&.{ 2, 2 }) == 9); try std.testing.expect(try mat.get(.{ 2, 2 }) == 9);
}
test "matrix get 3 dim" {
const mat = try Matrix.init(
std.testing.allocator,
&.{ 2, 2, 2 },
&.{ 1, 2, 3, 4, 5, 6, 7, 8 },
);
defer mat.deinit();
try std.testing.expect(try mat.get(&.{ 0, 0, 0 }) == 1);
try std.testing.expect(try mat.get(&.{ 1, 0, 0 }) == 2);
try std.testing.expect(try mat.get(&.{ 0, 1, 0 }) == 3);
try std.testing.expect(try mat.get(&.{ 1, 1, 0 }) == 4);
try std.testing.expect(try mat.get(&.{ 0, 0, 1 }) == 5);
try std.testing.expect(try mat.get(&.{ 1, 0, 1 }) == 6);
try std.testing.expect(try mat.get(&.{ 0, 1, 1 }) == 7);
try std.testing.expect(try mat.get(&.{ 1, 1, 1 }) == 8);
} }
test "matrix fill" { test "matrix fill" {
const mat = try Matrix.init(std.testing.allocator, &.{ 5, 5 }, &.{}); const mat = try Matrix.init(std.testing.allocator, 5, 5, &.{});
defer mat.deinit(); defer mat.deinit();
mat.fill(123); mat.fill(123);
@ -190,9 +230,9 @@ test "matrix fill" {
} }
test "matrix add" { test "matrix add" {
const mat1 = try Matrix.init(std.testing.allocator, &.{ 2, 2 }, &.{}); const mat1 = try Matrix.init(std.testing.allocator, 2, 2, &.{});
defer mat1.deinit(); defer mat1.deinit();
const mat2 = try Matrix.init(std.testing.allocator, &.{ 2, 2 }, &.{}); const mat2 = try Matrix.init(std.testing.allocator, 2, 2, &.{});
defer mat2.deinit(); defer mat2.deinit();
mat1.fill(1); mat1.fill(1);
@ -208,9 +248,9 @@ test "matrix add" {
} }
test "matrix sub" { test "matrix sub" {
const mat1 = try Matrix.init(std.testing.allocator, &.{ 2, 2 }, &.{}); const mat1 = try Matrix.init(std.testing.allocator, 2, 2, &.{});
defer mat1.deinit(); defer mat1.deinit();
const mat2 = try Matrix.init(std.testing.allocator, &.{ 2, 2 }, &.{}); const mat2 = try Matrix.init(std.testing.allocator, 2, 2, &.{});
defer mat2.deinit(); defer mat2.deinit();
mat1.fill(1); mat1.fill(1);
@ -228,7 +268,8 @@ test "matrix sub" {
test "matrix neg" { test "matrix neg" {
const mat = try Matrix.init( const mat = try Matrix.init(
std.testing.allocator, std.testing.allocator,
&.{ 3, 3 }, 3,
3,
&.{ 1, -2, 3, -4, 5, -6, 7, -8, 9 }, &.{ 1, -2, 3, -4, 5, -6, 7, -8, 9 },
); );
defer mat.deinit(); defer mat.deinit();
@ -240,3 +281,65 @@ test "matrix neg" {
try std.testing.expect(std.mem.eql(i64, mat_neg.numbers, expected_numbers)); try std.testing.expect(std.mem.eql(i64, mat_neg.numbers, expected_numbers));
} }
test "matrix row iterator" {
const mat = try Matrix.init(
std.testing.allocator,
3,
3,
&.{ 1, 2, 3, 4, 5, 6, 7, 8, 9 },
);
defer mat.deinit();
var row_iter = mat.row(0);
try std.testing.expect(row_iter.next() == 1);
try std.testing.expect(row_iter.next() == 2);
try std.testing.expect(row_iter.next() == 3);
try std.testing.expect(row_iter.next() == null);
row_iter = mat.row(1);
try std.testing.expect(row_iter.next() == 4);
try std.testing.expect(row_iter.next() == 5);
try std.testing.expect(row_iter.next() == 6);
try std.testing.expect(row_iter.next() == null);
row_iter = mat.row(2);
try std.testing.expect(row_iter.next() == 7);
try std.testing.expect(row_iter.next() == 8);
try std.testing.expect(row_iter.next() == 9);
try std.testing.expect(row_iter.next() == null);
}
test "matrix col iterator" {
const mat = try Matrix.init(
std.testing.allocator,
3,
3,
&.{ 1, 2, 3, 4, 5, 6, 7, 8, 9 },
);
defer mat.deinit();
var col_iter = mat.col(0);
try std.testing.expect(col_iter.next() == 1);
try std.testing.expect(col_iter.next() == 4);
try std.testing.expect(col_iter.next() == 7);
try std.testing.expect(col_iter.next() == null);
col_iter = mat.col(1);
try std.testing.expect(col_iter.next() == 2);
try std.testing.expect(col_iter.next() == 5);
try std.testing.expect(col_iter.next() == 8);
try std.testing.expect(col_iter.next() == null);
col_iter = mat.col(2);
try std.testing.expect(col_iter.next() == 3);
try std.testing.expect(col_iter.next() == 6);
try std.testing.expect(col_iter.next() == 9);
try std.testing.expect(col_iter.next() == null);
}