mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add TorchToTosa lowering for aten.where.self op (#1454)
parent 943cc9e736
commit ad6f5848cb
@@ -455,6 +455,7 @@ TOSA_PASS_SET = {
    "ArgmaxModule_keepDim",
    "ArgmaxModule_with_dim",
    "_LogSoftmaxModuleStable_basic",
    "ElementwiseAtenWhereSelfModule_basic",
    "LiftFreshCopyModule_basic",
    "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
    "ReduceSumDimIntListFloatModule_basic",
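For orientation, TOSA_PASS_SET is the allow-list of e2e tests that the TOSA backend configuration is expected to pass, so adding the new test name here opts it into that run. A minimal conceptual sketch of how such an allow-list can filter a test registry (the registry object and the filtering helper below are illustrative assumptions, not the project's actual harness code):

# Illustrative sketch only: consume a backend pass set as an allow-list over
# a registry of e2e tests, keeping just the tests named in the set.
TOSA_PASS_SET = {
    "ArgmaxModule_keepDim",
    "ElementwiseAtenWhereSelfModule_basic",  # the entry added by this commit
}

def select_tests(registry, pass_set):
    """Keep only the tests whose unique_name appears in the allow-list."""
    return [test for test in registry if test.unique_name in pass_set]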
@@ -3004,6 +3004,30 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
        "unimplemented: broadcasts other than same rank or zero ranked tensor.");
}

template <>
LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
    AtenWhereSelfOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // Not a tensor type.
  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
  if (!selfType)
    return rewriter.notifyMatchFailure(
        op, "Only tensor types input are currently supported");
  auto condType = adaptor.condition().getType().dyn_cast<TensorType>();
  if (!condType)
    return rewriter.notifyMatchFailure(
        op, "Only tensor types condition are currently supported");

  auto outType = getTypeConverter()->convertType(op.getType());
  rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, outType, adaptor.condition(),
                                              adaptor.self(), adaptor.other());

  return success();
}

template <>
LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
    AtenArangeStartStepOp op, OpAdaptor adaptor,
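As a quick sanity check of the semantics this pattern relies on, aten.where.self(condition, self, other) is an elementwise select whose three operands broadcast against each other, which is what tosa.select provides; a small eager-mode PyTorch sketch (shapes mirror the e2e test added below):

import torch

# The condition broadcasts along dim 1 and `other` is a rank-0 tensor that
# broadcasts everywhere; the result takes `self` where the condition is True
# and `other` elsewhere.
cond = torch.zeros(1, 1, 5, 5, dtype=torch.bool)
a = torch.rand(1, 12, 5, 5)
b = torch.rand(())
out = torch.ops.aten.where(cond, a, b)
assert out.shape == (1, 12, 5, 5)
assert torch.equal(out, torch.where(cond, a, b))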
@@ -3829,6 +3853,7 @@ public:
    INSERT_ATENOP_PATTERN(AtenMaxDimOp);
    INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
    INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
    INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
    INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
    INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
    INSERT_ATENOP_PATTERN(ValsemVariantAtenCopyOp);
@@ -134,6 +134,30 @@ def ElementwiseTernaryModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseAtenWhereSelfModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([1, 1, 5, 5], torch.bool, True),
        ([1, 12, 5, 5], torch.float32, True),
        ([], torch.float32, True),
    ])
    def forward(self, a, b, c):
        return torch.ops.aten.where(a, b, c)


@register_test_case(module_factory=lambda: ElementwiseAtenWhereSelfModule())
def ElementwiseAtenWhereSelfModule_basic(module, tu: TestUtils):
    module.forward(torch.zeros(1, 1, 5, 5, dtype=torch.bool), torch.rand(1, 12, 5, 5), torch.rand(()))


# ==============================================================================


class ElementwiseWhereSelfModule(torch.nn.Module):

    def __init__(self):
@@ -913,3 +913,20 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vten
  %0 = torch.aten.to.dtype %arg0, %int11, %false, %false, %none : !torch.vtensor<[3,5],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,5],i1>
  return %0 : !torch.vtensor<[3,5],i1>
}

// -----
// CHECK-LABEL:   func.func @torch.aten.where.self(
// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[1,1,5,5],i1>,
// CHECK-SAME:      %[[VAL_1:.*]]: !torch.vtensor<[1,12,5,5],f32>,
// CHECK-SAME:      %[[VAL_2:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,5,5],f32> {
// CHECK:           %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,5,5],i1> -> tensor<1x1x5x5xi1>
// CHECK:           %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,12,5,5],f32> -> tensor<1x12x5x5xf32>
// CHECK:           %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK:           %[[VAL_6:.*]] = "tosa.select"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<f32>) -> tensor<1x12x5x5xf32>
// CHECK:           %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x12x5x5xf32> -> !torch.vtensor<[1,12,5,5],f32>
// CHECK:           return %[[VAL_7]] : !torch.vtensor<[1,12,5,5],f32>
// CHECK:         }
func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !torch.vtensor<[1,12,5,5],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,5,5],f32> {
  %0 = torch.aten.where.self %arg0, %arg1, %arg2 : !torch.vtensor<[1,1,5,5],i1>, !torch.vtensor<[1,12,5,5],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,5,5],f32>
  return %0 : !torch.vtensor<[1,12,5,5],f32>
}
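For readers who want to reproduce this lowering outside of lit/FileCheck, the torch-mlir Python API of this era could compile a module directly to the TOSA dialect; a sketch (the torch_mlir.compile entry point and OutputType.TOSA value reflect the API around the time of this commit and are an assumption if run against a different version):

import torch
import torch_mlir

class WhereModule(torch.nn.Module):
    def forward(self, cond, a, b):
        return torch.ops.aten.where(cond, a, b)

# Compile down to the TOSA dialect; the printed IR should contain tosa.select.
module = torch_mlir.compile(
    WhereModule(),
    [torch.zeros(1, 1, 5, 5, dtype=torch.bool),
     torch.rand(1, 12, 5, 5),
     torch.rand(())],
    output_type=torch_mlir.OutputType.TOSA,
)
print(module)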