Add multiplication to Zig Matrix
This commit is contained in:
parent
09b4c503b9
commit
d74d89995f
@ -163,6 +163,37 @@ const Matrix = struct {
|
||||
return result;
|
||||
}
|
||||
|
||||
pub fn mul(self: Matrix, allocator: std.mem.Allocator, other: Matrix) !Matrix {
|
||||
if (self.cols != other.rows) {
|
||||
return error.IncompatibleDims;
|
||||
}
|
||||
|
||||
var result: Matrix = try Matrix.init(
|
||||
allocator,
|
||||
self.rows,
|
||||
other.cols,
|
||||
&.{},
|
||||
);
|
||||
|
||||
var currRow: u64 = 0;
|
||||
while (currRow < result.rows) : (currRow += 1) {
|
||||
var currCol: u64 = 0;
|
||||
while (currCol < result.cols) : (currCol += 1) {
|
||||
var currSum: i64 = 0;
|
||||
var rowIter = self.row(currRow);
|
||||
var colIter = other.col(currCol);
|
||||
|
||||
while (rowIter.next()) |r| {
|
||||
currSum += r * (colIter.next() orelse unreachable);
|
||||
}
|
||||
|
||||
try result.set(.{ currRow, currCol }, currSum);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
pub fn row(self: Matrix, rowIdx: u64) RowIterator {
|
||||
return RowIterator.init(
|
||||
self,
|
||||
@ -343,3 +374,31 @@ test "matrix col iterator" {
|
||||
try std.testing.expect(col_iter.next() == 9);
|
||||
try std.testing.expect(col_iter.next() == null);
|
||||
}
|
||||
|
||||
test "matrix mul" {
|
||||
const mat1 = try Matrix.init(
|
||||
std.testing.allocator,
|
||||
2,
|
||||
3,
|
||||
&.{ 1, 2, 3, 4, 5, 6 },
|
||||
);
|
||||
defer mat1.deinit();
|
||||
|
||||
const mat2 = try Matrix.init(
|
||||
std.testing.allocator,
|
||||
3,
|
||||
2,
|
||||
&.{ 1, 2, 3, 4, 5, 6 },
|
||||
);
|
||||
defer mat2.deinit();
|
||||
|
||||
const mat3 = try mat1.mul(std.testing.allocator, mat2);
|
||||
defer mat3.deinit();
|
||||
|
||||
try std.testing.expect(mat3.rows == 2);
|
||||
try std.testing.expect(mat3.cols == 2);
|
||||
try std.testing.expect(try mat3.get(.{ 0, 0 }) == 22);
|
||||
try std.testing.expect(try mat3.get(.{ 0, 1 }) == 28);
|
||||
try std.testing.expect(try mat3.get(.{ 1, 0 }) == 49);
|
||||
try std.testing.expect(try mat3.get(.{ 1, 1 }) == 64);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user