mirror of https://github.com/llvm/torch-mlir
[MLIR][ONNX] Add OnnxToTorch support for ReduceLogSum Op (#3229)
This commit adds the OnnxToTorch support for ReduceLogSum oppull/3238/head
parent
2eac8a992f
commit
ac11ec796d
|
@ -962,6 +962,32 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
/*storeValue=*/data, keepDims,
|
||||
noop_with_empty_axes, false);
|
||||
});
|
||||
patterns.onOp("ReduceLogSum", 1,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value data;
|
||||
int64_t keepDims, noop_with_empty_axes;
|
||||
if (binder.tensorOperandAtIndex(data, 0) ||
|
||||
binder.tensorResultType(resultType) ||
|
||||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
|
||||
binder.s64IntegerAttr(noop_with_empty_axes,
|
||||
"noop_with_empty_axes", 0))
|
||||
return failure();
|
||||
|
||||
auto reducedSumBool =
|
||||
reducedSumImpl(binder, rewriter, data, resultType,
|
||||
/*storeValue=*/data, keepDims,
|
||||
noop_with_empty_axes, true);
|
||||
|
||||
if (failed(reducedSumBool))
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"Failed to perform sum operation on square of operand");
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenLogOp>(
|
||||
binder.op, resultType, data);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"ReduceMean", 1,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
|
|
|
@ -942,6 +942,57 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor<
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example
|
||||
func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
|
||||
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[1,1,1],f32>
|
||||
// CHECK: return %[[LOG]] : !torch.vtensor<[1,1,1],f32>
|
||||
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
|
||||
return %0 : !torch.vtensor<[1,1,1],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_reduce_log_sum_keep_dims_example
|
||||
func.func @test_reduce_log_sum_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
|
||||
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32>
|
||||
// CHECK: return %[[LOG]] : !torch.vtensor<[3,2,1],f32>
|
||||
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32>
|
||||
return %0 : !torch.vtensor<[3,2,1],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_reduce_log_sum_do_not_keepdims_example
|
||||
func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f32>, %arg1:!torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
|
||||
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2],f32> -> !torch.vtensor<[3,2],f32>
|
||||
// CHECK: return %[[LOG]] : !torch.vtensor<[3,2],f32>
|
||||
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
|
||||
return %0 : !torch.vtensor<[3,2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_reduce_mean_negative_axes_keepdims_example
|
||||
func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
|
||||
// CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||
|
|
Loading…
Reference in New Issue