Add multiplication to Zig Matrix
This commit is contained in:
parent
09b4c503b9
commit
d74d89995f
@ -163,6 +163,37 @@ const Matrix = struct {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn mul(self: Matrix, allocator: std.mem.Allocator, other: Matrix) !Matrix {
|
||||||
|
if (self.cols != other.rows) {
|
||||||
|
return error.IncompatibleDims;
|
||||||
|
}
|
||||||
|
|
||||||
|
var result: Matrix = try Matrix.init(
|
||||||
|
allocator,
|
||||||
|
self.rows,
|
||||||
|
other.cols,
|
||||||
|
&.{},
|
||||||
|
);
|
||||||
|
|
||||||
|
var currRow: u64 = 0;
|
||||||
|
while (currRow < result.rows) : (currRow += 1) {
|
||||||
|
var currCol: u64 = 0;
|
||||||
|
while (currCol < result.cols) : (currCol += 1) {
|
||||||
|
var currSum: i64 = 0;
|
||||||
|
var rowIter = self.row(currRow);
|
||||||
|
var colIter = other.col(currCol);
|
||||||
|
|
||||||
|
while (rowIter.next()) |r| {
|
||||||
|
currSum += r * (colIter.next() orelse unreachable);
|
||||||
|
}
|
||||||
|
|
||||||
|
try result.set(.{ currRow, currCol }, currSum);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
pub fn row(self: Matrix, rowIdx: u64) RowIterator {
|
pub fn row(self: Matrix, rowIdx: u64) RowIterator {
|
||||||
return RowIterator.init(
|
return RowIterator.init(
|
||||||
self,
|
self,
|
||||||
@ -343,3 +374,31 @@ test "matrix col iterator" {
|
|||||||
try std.testing.expect(col_iter.next() == 9);
|
try std.testing.expect(col_iter.next() == 9);
|
||||||
try std.testing.expect(col_iter.next() == null);
|
try std.testing.expect(col_iter.next() == null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test "matrix mul" {
|
||||||
|
const mat1 = try Matrix.init(
|
||||||
|
std.testing.allocator,
|
||||||
|
2,
|
||||||
|
3,
|
||||||
|
&.{ 1, 2, 3, 4, 5, 6 },
|
||||||
|
);
|
||||||
|
defer mat1.deinit();
|
||||||
|
|
||||||
|
const mat2 = try Matrix.init(
|
||||||
|
std.testing.allocator,
|
||||||
|
3,
|
||||||
|
2,
|
||||||
|
&.{ 1, 2, 3, 4, 5, 6 },
|
||||||
|
);
|
||||||
|
defer mat2.deinit();
|
||||||
|
|
||||||
|
const mat3 = try mat1.mul(std.testing.allocator, mat2);
|
||||||
|
defer mat3.deinit();
|
||||||
|
|
||||||
|
try std.testing.expect(mat3.rows == 2);
|
||||||
|
try std.testing.expect(mat3.cols == 2);
|
||||||
|
try std.testing.expect(try mat3.get(.{ 0, 0 }) == 22);
|
||||||
|
try std.testing.expect(try mat3.get(.{ 0, 1 }) == 28);
|
||||||
|
try std.testing.expect(try mat3.get(.{ 1, 0 }) == 49);
|
||||||
|
try std.testing.expect(try mat3.get(.{ 1, 1 }) == 64);
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user