torch-mlir/lib/Conversion/TorchToStablehlo
Xinyu Yang ae4724763a
[Stablehlo] Enhance broadcast pattern in matmul Ops (#3161)
This makes the test `MatmulStaticBroadcast_basic` pass on the StableHLO backend:
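```python
import torch

from torch_mlir_e2e_test.annotations import annotate_args, export
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case


class MatmulStaticBroadcast(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        # Batch dims (4, 1) and (8, 1, 5) must broadcast to (8, 4, 5).
        ([4, 1, 6, 7], torch.float32, True),
        ([8, 1, 5, 7, 6], torch.float32, True),
    ])
    def forward(self, lhs, rhs):
        return torch.matmul(lhs, rhs)


@register_test_case(module_factory=lambda: MatmulStaticBroadcast())
def MatmulStaticBroadcast_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 1, 6, 7), tu.rand(8, 1, 5, 7, 6))
```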
2024-04-16 10:10:36 +08:00
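To see which broadcast the test exercises, the sketch below computes the result shape of `torch.matmul` for two inputs of rank at least 2, following PyTorch's documented broadcasting rules: the last two dimensions are the matrix dimensions, and the remaining batch dimensions are right-aligned and broadcast (each pair must be equal or one of them must be 1). It is a shape-only illustration under those assumptions, not the lowering in `Linear.cpp`; the helper names `broadcast_batch_dims` and `matmul_result_shape` are hypothetical.

```python
from itertools import zip_longest


def broadcast_batch_dims(lhs_batch, rhs_batch):
    """Broadcast two batch-shape tuples (hypothetical helper).

    Shapes are aligned from the right; missing leading dims count as 1,
    and each pair of sizes must be equal or one of them must be 1.
    """
    result = []
    # Walk both shapes from the trailing dimension towards the front.
    for lhs_dim, rhs_dim in zip_longest(
        reversed(lhs_batch), reversed(rhs_batch), fillvalue=1
    ):
        if lhs_dim != rhs_dim and lhs_dim != 1 and rhs_dim != 1:
            raise ValueError(
                f"cannot broadcast batch dims {lhs_batch} and {rhs_batch}"
            )
        # A size-1 dim takes the other side's size; otherwise they are equal.
        result.append(rhs_dim if lhs_dim == 1 else lhs_dim)
    return tuple(reversed(result))


def matmul_result_shape(lhs_shape, rhs_shape):
    """Result shape of a batched matmul whose inputs are both at least 2-D."""
    if len(lhs_shape) < 2 or len(rhs_shape) < 2:
        raise ValueError("this sketch only handles inputs with rank >= 2")
    *lhs_batch, m, k = lhs_shape
    *rhs_batch, k_rhs, n = rhs_shape
    if k != k_rhs:
        raise ValueError(f"contraction dims differ: {k} vs {k_rhs}")
    return broadcast_batch_dims(tuple(lhs_batch), tuple(rhs_batch)) + (m, n)


# Shapes from the MatmulStaticBroadcast test: (6 x 7) @ (7 x 6) per batch.
print(matmul_result_shape((4, 1, 6, 7), (8, 1, 5, 7, 6)))  # (8, 4, 5, 6, 6)
```

Here the lhs batch dims `(4, 1)` are padded to `(1, 4, 1)` and broadcast against `(8, 1, 5)`, so both operands need broadcasting on different axes before a batched dot can be formed.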
| File | Last commit | Date |
|---|---|---|
| Basic.cpp | Added 2 Ops: Floor divide scalar and Floor divide scalar mode (#3156) | 2024-04-15 13:45:10 -07:00 |
| CMakeLists.txt | [Stablehlo] add torch_to_stablehlo::getBackendTypeForScalarType (#2975) | 2024-03-04 23:31:54 +08:00 |
| GatherScatter.cpp | Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3130) | 2024-04-11 06:47:35 -07:00 |
| Linear.cpp | [Stablehlo] Enhance broadcast pattern in matmul Ops (#3161) | 2024-04-16 10:10:36 +08:00 |
| Pooling.cpp | Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3130) | 2024-04-11 06:47:35 -07:00 |
| PopulatePatterns.h | [StableHLO] Support for slice_scatter (#1960) | 2023-03-22 13:41:04 -07:00 |
| Reduction.cpp | Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3130) | 2024-04-11 06:47:35 -07:00 |
| StablehloLegalizeUtils.cpp | [stablehlo] Reduce unnecessary template specialization code (#3047) | 2024-04-01 14:18:49 -07:00 |
| TorchToStablehlo.cpp | [Stablehlo] support dynamic shape when convert aten.fill.Scalar (#2349) | 2023-07-27 18:35:25 +08:00 |
| Utils.cpp | [Stablehlo] add torch_to_stablehlo::getBackendTypeForScalarType (#2975) | 2024-03-04 23:31:54 +08:00 |
| Utils.h | [Stablehlo] add torch_to_stablehlo::getBackendTypeForScalarType (#2975) | 2024-03-04 23:31:54 +08:00 |
| ViewLike.cpp | [Stablehlo] lowering aten.view to shape.num_elements + stablehlo.comp… (#3125) | 2024-04-09 14:54:57 +08:00 |