Add multiplication to Zig Matrix

This commit is contained in:
Maciej Krzyżanowski 2024-09-06 07:35:32 +02:00
parent 09b4c503b9
commit d74d89995f

View File

@ -163,6 +163,37 @@ const Matrix = struct {
return result; return result;
} }
pub fn mul(self: Matrix, allocator: std.mem.Allocator, other: Matrix) !Matrix {
if (self.cols != other.rows) {
return error.IncompatibleDims;
}
var result: Matrix = try Matrix.init(
allocator,
self.rows,
other.cols,
&.{},
);
var currRow: u64 = 0;
while (currRow < result.rows) : (currRow += 1) {
var currCol: u64 = 0;
while (currCol < result.cols) : (currCol += 1) {
var currSum: i64 = 0;
var rowIter = self.row(currRow);
var colIter = other.col(currCol);
while (rowIter.next()) |r| {
currSum += r * (colIter.next() orelse unreachable);
}
try result.set(.{ currRow, currCol }, currSum);
}
}
return result;
}
pub fn row(self: Matrix, rowIdx: u64) RowIterator { pub fn row(self: Matrix, rowIdx: u64) RowIterator {
return RowIterator.init( return RowIterator.init(
self, self,
@ -343,3 +374,31 @@ test "matrix col iterator" {
try std.testing.expect(col_iter.next() == 9); try std.testing.expect(col_iter.next() == 9);
try std.testing.expect(col_iter.next() == null); try std.testing.expect(col_iter.next() == null);
} }
test "matrix mul" {
const mat1 = try Matrix.init(
std.testing.allocator,
2,
3,
&.{ 1, 2, 3, 4, 5, 6 },
);
defer mat1.deinit();
const mat2 = try Matrix.init(
std.testing.allocator,
3,
2,
&.{ 1, 2, 3, 4, 5, 6 },
);
defer mat2.deinit();
const mat3 = try mat1.mul(std.testing.allocator, mat2);
defer mat3.deinit();
try std.testing.expect(mat3.rows == 2);
try std.testing.expect(mat3.cols == 2);
try std.testing.expect(try mat3.get(.{ 0, 0 }) == 22);
try std.testing.expect(try mat3.get(.{ 0, 1 }) == 28);
try std.testing.expect(try mat3.get(.{ 1, 0 }) == 49);
try std.testing.expect(try mat3.get(.{ 1, 1 }) == 64);
}