diff --git a/zig-matrix/src/main.zig b/zig-matrix/src/main.zig index eecb2e1..083d601 100644 --- a/zig-matrix/src/main.zig +++ b/zig-matrix/src/main.zig @@ -163,6 +163,37 @@ const Matrix = struct { return result; } + pub fn mul(self: Matrix, allocator: std.mem.Allocator, other: Matrix) !Matrix { + if (self.cols != other.rows) { + return error.IncompatibleDims; + } + + var result: Matrix = try Matrix.init( + allocator, + self.rows, + other.cols, + &.{}, + ); + + var currRow: u64 = 0; + while (currRow < result.rows) : (currRow += 1) { + var currCol: u64 = 0; + while (currCol < result.cols) : (currCol += 1) { + var currSum: i64 = 0; + var rowIter = self.row(currRow); + var colIter = other.col(currCol); + + while (rowIter.next()) |r| { + currSum += r * (colIter.next() orelse unreachable); + } + + try result.set(.{ currRow, currCol }, currSum); + } + } + + return result; + } + pub fn row(self: Matrix, rowIdx: u64) RowIterator { return RowIterator.init( self, @@ -343,3 +374,31 @@ test "matrix col iterator" { try std.testing.expect(col_iter.next() == 9); try std.testing.expect(col_iter.next() == null); } + +test "matrix mul" { + const mat1 = try Matrix.init( + std.testing.allocator, + 2, + 3, + &.{ 1, 2, 3, 4, 5, 6 }, + ); + defer mat1.deinit(); + + const mat2 = try Matrix.init( + std.testing.allocator, + 3, + 2, + &.{ 1, 2, 3, 4, 5, 6 }, + ); + defer mat2.deinit(); + + const mat3 = try mat1.mul(std.testing.allocator, mat2); + defer mat3.deinit(); + + try std.testing.expect(mat3.rows == 2); + try std.testing.expect(mat3.cols == 2); + try std.testing.expect(try mat3.get(.{ 0, 0 }) == 22); + try std.testing.expect(try mat3.get(.{ 0, 1 }) == 28); + try std.testing.expect(try mat3.get(.{ 1, 0 }) == 49); + try std.testing.expect(try mat3.get(.{ 1, 1 }) == 64); +}