mirror of https://github.com/llvm/torch-mlir
[onnx] Fix `onnx.Hardmax` lowering to torch (#3624)
The lowering to torch makes assumptions about the dimensions / types of reduce max and onehot. We need to correct for the expected torch behavior.

pull/3629/head
parent 026dfade64
commit d3695a97a0
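The in-tree comment (visible in the first hunk below) notes that `onnx.Hardmax` can be expanded into Python code built on `torch.nn.functional`. As a reading aid, here is a minimal sketch of the decomposition the fixed lowering corresponds to, assuming standard PyTorch `argmax`/`one_hot`/`permute` semantics; the `hardmax` helper and the final sanity check are illustrative, not part of the commit:

import torch
import torch.nn.functional as F

def hardmax(tensor, axis=-1):
    axis = axis % tensor.dim()          # normalize a negative axis
    # argmax drops `axis`, so its result has the input shape minus that dim
    indices = torch.argmax(tensor, dim=axis)
    # one_hot appends the class dimension last, not at `axis`
    onehot = F.one_hot(indices, num_classes=tensor.shape[axis])
    # move the trailing class dimension back to `axis`
    perm = list(range(axis)) + [onehot.dim() - 1] + list(range(axis, onehot.dim() - 1))
    return onehot.permute(*perm).to(tensor.dtype)

x = torch.randn(3, 4, 5)
y = hardmax(x, axis=-2)                 # mirrors the updated lit test below
assert y.shape == x.shape and bool((y.sum(dim=1) == 1).all())

The three steps (argmax with a properly reduced integer result type, one_hot against the runtime class count, and a permute that restores the class dimension to `axis`) are exactly what the corrected pattern emits.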
@@ -3122,7 +3122,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
       });
   patterns.onOp(
-      "Hardmax", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+      "Hardmax", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         // onnx.Hardmax can be expanded into the following python code:
         //
         // import torch.nn.functional as F
@@ -3143,33 +3143,64 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         int64_t axisValue;
         Value input, axis;
         if (binder.tensorOperand(input) ||
-            binder.s64IntegerAttr(axisValue, "axis") ||
+            binder.s64IntegerAttr(axisValue, "axis", -1) ||
             binder.tensorResultType(resultType))
           return failure();

         auto loc = binder.getLoc();
+        auto inputTy = cast<Torch::ValueTensorType>(input.getType());
+
+        if (axisValue < 0)
+          axisValue += inputTy.getSizes().size();
+
-        std::optional<int64_t> axisIntTorch =
-            onnxDtypeIntToTorchDtypeInt(axisValue);
-        if (!axisIntTorch.has_value())
-          return rewriter.notifyMatchFailure(
-              binder.op, "unimplemented support for the given axis conversion");
         axis = rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(axisIntTorch.value()));
+            loc, rewriter.getI64IntegerAttr(axisValue));

         // torch.argmax
         Value constKeepDims = rewriter.create<Torch::ConstantBoolOp>(
             loc, rewriter.getType<Torch::BoolType>(),
             rewriter.getBoolAttr(false));

+        SmallVector<int64_t> argmaxShape;
+        for (int i = 0, s = inputTy.getSizes().size(); i < s; ++i) {
+          if (i == axisValue)
+            continue;
+          argmaxShape.push_back(inputTy.getSizes()[i]);
+        }
+
+        auto argmaxTy = rewriter.getType<Torch::ValueTensorType>(
+            argmaxShape, rewriter.getIntegerType(32, IntegerType::Signed));
         Value argmax = rewriter.create<Torch::AtenArgmaxOp>(
-            loc, resultType, input, axis, constKeepDims);
+            loc, argmaxTy, input, axis, constKeepDims);

         // one_hot
-        Value oneInt = rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(1));
-        rewriter.replaceOpWithNewOp<Torch::AtenOneHotOp>(binder.op, resultType,
-                                                         argmax, oneInt);
+        SmallVector<int64_t> onehotShape(argmaxShape);
+        onehotShape.push_back(inputTy.getSizes()[axisValue]);
+        auto onehotTy = rewriter.getType<Torch::ValueTensorType>(
+            onehotShape, resultType.getDtype());
+        Value numClasses =
+            rewriter.create<Torch::AtenSizeIntOp>(binder.getLoc(), input, axis);
+        Value onehot = rewriter.create<Torch::AtenOneHotOp>(
+            binder.getLoc(), onehotTy, argmax, numClasses);
+
+        SmallVector<int64_t> permutation;
+        for (int i = 0; i < axisValue; ++i)
+          permutation.push_back(i);
+        permutation.push_back(onehotShape.size() - 1);
+        for (int i = axisValue, s = onehotShape.size(); i < s - 1; ++i)
+          permutation.push_back(i);
+
+        SmallVector<Value> permValues;
+        for (auto d : permutation) {
+          permValues.push_back(rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(d)));
+        }
+
+        Value permuteDims = rewriter.create<Torch::PrimListConstructOp>(
+            loc, Torch::ListType::get(rewriter.getType<Torch::IntType>()),
+            permValues);
+        rewriter.replaceOpWithNewOp<Torch::AtenPermuteOp>(binder.op, resultType,
+                                                          onehot, permuteDims);
         return success();
       });
   patterns.onOp("LpNormalization", 1,
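Before the test changes, here is a hypothetical shape walk-through of the `[3,4,5]` / `axis = -2` case exercised below (plain Python, illustrative values only, not part of the commit):

input_shape = [3, 4, 5]
axis = -2 % len(input_shape)                                        # -> 1
argmax_shape = [d for i, d in enumerate(input_shape) if i != axis]  # [3, 5]
onehot_shape = argmax_shape + [input_shape[axis]]                   # [3, 5, 4]
perm = (list(range(axis)) + [len(onehot_shape) - 1]
        + list(range(axis, len(onehot_shape) - 1)))                 # [0, 2, 1]
result_shape = [onehot_shape[i] for i in perm]                      # [3, 4, 5]
assert result_shape == input_shape

The `[0, 2, 1]` permutation is where the `%[[PERM0]]`, `%[[PERM2]]`, `%[[PERM1]]` constants in the updated CHECK lines come from.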
@@ -1471,9 +1471,18 @@ func.func @test_hardswish(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<

 // CHECK-LABEL: func.func @test_hardmax
 func.func @test_hardmax(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[ARGMAX:.+]] = torch.aten.argmax %arg0, %int6, %false : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.bool -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: torch.aten.one_hot %[[ARGMAX]], %int1 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32>
-  %0 = torch.operator "onnx.Hardmax"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[AXIS:.+]] = torch.constant.int 1
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[ARGMAX:.+]] = torch.aten.argmax %arg0, %[[AXIS]], %[[FALSE]]
+  // CHECK: %[[CLASSES:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
+  // CHECK: %[[ONEHOT:.+]] = torch.aten.one_hot %[[ARGMAX]], %[[CLASSES]]
+  // CHECK: %[[PERM0:.+]] = torch.constant.int 0
+  // CHECK: %[[PERM2:.+]] = torch.constant.int 2
+  // CHECK: %[[PERM1:.+]] = torch.constant.int 1
+  // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[PERM0]], %[[PERM2]], %[[PERM1]]
+  // CHECK: %[[PERMUTE:.+]] = torch.aten.permute %[[ONEHOT]], %[[LIST]]
+  // CHECK: return %[[PERMUTE]]
+  %0 = torch.operator "onnx.Hardmax"(%arg0) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
   return %0 : !torch.vtensor<[3,4,5],f32>
 }

@@ -1510,16 +1519,6 @@ func.func @test_onehot_negative_indices(%arg0: !torch.vtensor<[3],si64>, %arg1:

 // -----

-// CHECK-LABEL: func.func @test_hardmax
-func.func @test_hardmax(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[ARGMAX:.+]] = torch.aten.argmax %arg0, %int6, %false : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.bool -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: torch.aten.one_hot %[[ARGMAX]], %int1 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32>
-  %0 = torch.operator "onnx.Hardmax"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
-  return %0 : !torch.vtensor<[3,4,5],f32>
-}
-
-// -----
-
 // CHECK-LABEL: @test_lpnormalization
 func.func @test_lpnormalization(%arg0: !torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,5,6,7],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[CST2:.*]] = torch.constant.int 2
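Aside: the `%int6` in the removed CHECK lines is the footprint of the original bug. The axis attribute was routed through the ONNX-dtype-to-torch-dtype conversion, so an axis of `1` (which collides with ONNX's FLOAT dtype code) came back as torch's float32 scalar-type code `6`. A hypothetical sketch of the mix-up, with a one-entry dictionary standing in for the real `onnxDtypeIntToTorchDtypeInt` helper:

# Illustrative stand-in for the dtype table; only the colliding entry shown.
ONNX_DTYPE_TO_TORCH_DTYPE = {1: 6}   # ONNX FLOAT (1) -> torch float32 code (6)
axis_attr = 1                        # the onnx.Hardmax axis attribute
bogus_axis = ONNX_DTYPE_TO_TORCH_DTYPE[axis_attr]   # -> 6, hence `%int6`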