mirror of https://github.com/llvm/torch-mlir

[MLIR][TORCH] Add OnnxToTorch lowering for ReduceL1 Op (#3146)

Adds OnnxToTorch lowering for the ReduceL1 op.

branch: pull/3170/head
parent: af5509c5d9
commit: a0232e9ebd
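For orientation: ONNX ReduceL1 computes the sum of absolute values over the given axes, so this commit lowers it as torch.aten.abs followed by the shared ReduceSum lowering. A minimal PyTorch sketch of the intended semantics (illustrative values only, not code from this patch):

import torch

# ONNX ReduceL1(data, axes, keepdims) == sum(|data|, axes, keepdims).
data = torch.tensor([[1.0, -2.0], [-3.0, 4.0]])
reduce_l1 = torch.sum(torch.abs(data), dim=[1], keepdim=True)
assert torch.equal(reduce_l1, torch.tensor([[3.0], [7.0]]))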
@@ -39,6 +39,127 @@ Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
  return rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
                                            rewriter.getType<T>(), ofItem);
}

// In case the ReduceSum op is not the first operation performed on the data,
// we provide the original operand through storeResult. It is overwritten when
// the result is passed on to another operation, and is used for
// noop_with_empty_axes handling before that.
LogicalResult reducedSumImpl(OpBinder binder,
                             ConversionPatternRewriter &rewriter, Value data,
                             Torch::ValueTensorType resultType,
                             Value &storeResult, int64_t keepDims,
                             int64_t noop_with_empty_axes,
                             bool isIntermediateOp) {

  SmallVector<Value> axesList;
  Value axesVal;
  if (!binder.tensorOperandAtIndex(axesVal, 1)) {
    auto inputType = data.getType().dyn_cast<Torch::ValueTensorType>();
    if (!inputType.hasSizes() || !resultType.hasSizes()) {
      return rewriter.notifyMatchFailure(
          binder.op, "unimplemented: expected input and result to have shapes");
    }

    if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) {
      SmallVector<int64_t> inputShape{inputType.getSizes()};
      SmallVector<int64_t> resultShape{resultType.getSizes()};
      // If the shapes are equal, none of the dims is reduced.
      if (llvm::equal(inputShape, resultShape)) {
        // Simply forward the input and return.
        rewriter.replaceOp(binder.op, data);
        return success();
      }
      if (areAllElementsDistinct(inputShape)) {
        // The check that the input shape elements are distinct is needed
        // for cases like:
        //   Input: [3, 2, 2] -> Output: [3, 2]
        // where the input and output shapes alone cannot tell whether dim 1
        // or dim 2 was reduced. The check rules out such ambiguous cases.
        SmallVector<int64_t> reduceDims;
        unsigned resultShapeCounter = 0;
        for (unsigned i = 0; i < inputShape.size(); i++) {
          if (resultShapeCounter < resultShape.size() &&
              inputShape[i] == resultShape[resultShapeCounter]) {
            resultShapeCounter++;
          } else {
            reduceDims.push_back(i);
            if (resultShapeCounter < resultShape.size() &&
                resultShape[resultShapeCounter] == 1)
              resultShapeCounter++;
          }
        }
        for (auto i : reduceDims) {
          axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), rewriter.getI64IntegerAttr(i)));
        }
      }
    }
    if (axesList.empty()) {
      Torch::BaseTensorType axesType =
          axesVal.getType().cast<Torch::BaseTensorType>();
      auto axesTy = dyn_cast<Torch::ValueTensorType>(axesVal.getType());
      auto axesShape = axesTy.getSizes();
      if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize)
        return failure();

      Value zero = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getI64IntegerAttr(0));
      SmallVector<int64_t> selectSizes{1};
      auto selType = rewriter.getType<Torch::ValueTensorType>(
          selectSizes, axesType.getOptionalDtype());
      int64_t numAxes = axesShape[0];
      for (int64_t i = 0; i < numAxes; ++i) {
        Value iv = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getI64IntegerAttr(i));
        Value extract = rewriter.create<Torch::AtenSelectIntOp>(
            binder.getLoc(), selType, axesVal, zero, iv);
        Value dim = rewriter.create<Torch::AtenItemOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
        axesList.push_back(dim);
      }
    }
  }

  SmallVector<int64_t> axesInts;
  if (!binder.s64IntegerArrayAttr(axesInts, "axes", {})) {
    for (int64_t i = 0, s = axesInts.size(); i < s; ++i) {
      Value iv = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getI64IntegerAttr(axesInts[i]));
      axesList.push_back(iv);
    }
  }

  // Do not include the absolute value in the noop case.
  if (axesList.empty() && noop_with_empty_axes) {
    rewriter.replaceOp(binder.op, storeResult);
    return success();
  }

  Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
      binder.getLoc(),
      Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
      axesList);
  Value keepDimBool =
      rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), keepDims);
  Value dType = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
  // If the reduced sum is an intermediate op to be passed into another
  // operation, we do not want to replace the op being rewritten; instead we
  // create a new op and store its result in storeResult.
  if (!isIntermediateOp) {
    rewriter.replaceOpWithNewOp<Torch::AtenSumDimIntListOp>(
        binder.op, resultType, data, dimValueList, keepDimBool,
        /*dtype=*/dType);
  } else {
    storeResult = rewriter.create<Torch::AtenSumDimIntListOp>(
        binder.getLoc(), resultType, data, dimValueList, keepDimBool,
        /*dtype=*/dType);
  }
  return success();
}
} // namespace

void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
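A note on the static-shape fast path in reducedSumImpl above: when both the input and result shapes are statically known, the reduced dims can be derived from the shapes alone, so the lowering need not wait for the axes values to be known at runtime, which the downstream pipeline also expects. The walk is only sound when the input dims are pairwise distinct. A small Python sketch of the same walk (hypothetical helper name, for illustration):

def infer_reduce_dims(input_shape, result_shape):
    # Ambiguous when input dims repeat: [3, 2, 2] -> [3, 2] could have
    # reduced dim 1 or dim 2, so bail out in that case.
    if len(set(input_shape)) != len(input_shape):
        return None
    reduce_dims, j = [], 0
    for i, size in enumerate(input_shape):
        if j < len(result_shape) and size == result_shape[j]:
            j += 1  # this dim survived the reduction
        else:
            reduce_dims.append(i)  # this dim was reduced away
            if j < len(result_shape) and result_shape[j] == 1:
                j += 1  # keepdims=1 leaves a unit dim behind
    return reduce_dims

assert infer_reduce_dims([4, 3, 2], [4, 2]) == [1]     # keepdims=0
assert infer_reduce_dims([4, 3, 2], [4, 1, 2]) == [1]  # keepdims=1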
@@ -758,124 +879,41 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
            binder.op, resultType, operand, vAlpha, vScale, vInputScale);
        return success();
      });
  patterns.onOp("ReduceL1", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  int64_t keepDims, noop_with_empty_axes;
                  Value operand;
                  if (binder.tensorOperandAtIndex(operand, 0) ||
                      binder.tensorResultType(resultType) ||
                      binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
                      binder.s64IntegerAttr(noop_with_empty_axes,
                                            "noop_with_empty_axes", 0))
                    return failure();

                  Value data = rewriter.create<Torch::AtenAbsOp>(
                      binder.getLoc(), operand.getType(), operand);

                  return reducedSumImpl(binder, rewriter, data, resultType,
                                        /*storeValue=*/operand, keepDims,
                                        noop_with_empty_axes, false);
                });
  patterns.onOp("ReduceSum", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value data;
                  int64_t keepDims, noop_with_empty_axes;
                  if (binder.tensorOperandAtIndex(data, 0) ||
                      binder.tensorResultType(resultType) ||
                      binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
                      binder.s64IntegerAttr(noop_with_empty_axes,
                                            "noop_with_empty_axes", 0))
                    return failure();

                  return reducedSumImpl(binder, rewriter, data, resultType,
                                        /*storeValue=*/data, keepDims,
                                        noop_with_empty_axes, false);
                });
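Note the /*storeValue=*/operand argument in the ReduceL1 registration: per the ONNX spec, when axes is empty and noop_with_empty_axes is set, the op returns its input unchanged, so the noop path must yield the original operand rather than its absolute value. A hedged Python sketch of that corner of the semantics (an illustration, not the converter itself):

import torch

def onnx_reduce_l1(data, axes, keepdims=True, noop_with_empty_axes=False):
    # Sketch of the ONNX-side behavior the lowering has to match.
    if not axes:
        if noop_with_empty_axes:
            return data  # the original data, *not* data.abs()
        axes = list(range(data.dim()))  # empty axes: reduce all dims
    return torch.sum(data.abs(), dim=axes, keepdim=keepdims)

x = torch.tensor([[-1.0, 2.0]])
assert torch.equal(onnx_reduce_l1(x, [], noop_with_empty_axes=True), x)
assert torch.equal(onnx_reduce_l1(x, []), torch.tensor([[3.0]]))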
[Removed here: the previous inline ReduceSum lowering body, which duplicated
the shape-inference walk, axes extraction, and sum construction now factored
into reducedSumImpl above.]
  patterns.onOp(
      "ReduceMean", 1,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
@@ -2384,8 +2384,6 @@ ONNX_XFAIL_SET = {
    "RandModule_basic",

    # Failure - onnx_lowering: onnx.ReduceL1
    "ReduceL1NormComplexModule_basic",

    # Failure - onnx_lowering: onnx.ReduceL2

[Removed here: "ReduceL1NormModule_basic" and
"ReduceL1NormWithDTypeModule_basic", which now lower successfully.]
@@ -2529,6 +2527,13 @@ ONNX_XFAIL_SET = {
}

if torch_version_for_comparison() >= version.parse("2.4.0.dev"):
    ONNX_XFAIL_SET = ONNX_XFAIL_SET | {
        # ERROR: Found dtype (torch.float64) but expected (torch.float32)
        "ReduceL1NormWithDTypeModule_basic",
    }

ONNX_CRASHING_SET = {
    "FakeQuantizePerTensorAffineModule_basic",
    "FakeQuantizePerTensorAffineDynamicShapeModule_basic",
@@ -863,6 +863,59 @@ func.func @test_reduce_max_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1>
  return %0 : !torch.vtensor<[4],i1>
}

// -----

// CHECK-LABEL: func.func @test_reduce_l1_default_axes_keepdims_example
func.func @test_reduce_l1_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[ABS:.+]] = torch.aten.abs %arg0 : !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
  // CHECK: %[[INT0:.+]] = torch.constant.int 0
  // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
  // CHECK: %[[NONE:.+]] = torch.constant.none
  // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[ABS]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
  // CHECK: return %[[SUM]] : !torch.vtensor<[1,1,1],f32>
  %0 = torch.operator "onnx.ReduceL1"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
  return %0 : !torch.vtensor<[1,1,1],f32>
}

// -----

// CHECK-LABEL: func.func @test_reduce_l1_keep_dims_example
func.func @test_reduce_l1_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[ABS:.+]] = torch.aten.abs %arg0 : !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
  // CHECK: %[[INT0:.+]] = torch.constant.int 0
  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
  // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
  // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
  // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
  // CHECK: %[[NONE:.+]] = torch.constant.none
  // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[ABS]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
  // CHECK: return %[[SUM]] : !torch.vtensor<[3,2,1],f32>
  %0 = torch.operator "onnx.ReduceL1"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32>
  return %0 : !torch.vtensor<[3,2,1],f32>
}

// -----

// CHECK-LABEL: func.func @test_reduce_l1_do_not_keepdims_example
func.func @test_reduce_l1_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[ABS:.+]] = torch.aten.abs %arg0 : !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
  // CHECK: %[[INT0:.+]] = torch.constant.int 0
  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
  // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
  // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
  // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
  // CHECK: %[[NONE:.+]] = torch.constant.none
  // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[ABS]], %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
  // CHECK: return %[[SUM]] : !torch.vtensor<[3,2],f32>
  %0 = torch.operator "onnx.ReduceL1"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
  return %0 : !torch.vtensor<[3,2],f32>
}

// -----

// CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example
func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[INT0:.+]] = torch.constant.int 0