mirror of https://github.com/llvm/torch-mlir
[onnx] Adding lowering for `onnx.Size` operation (#2985)
We can support `onnx.Size` by requesting the size of each dimension, taking the product of the results, then packing it into a tensor.

---------

Co-authored-by: Scott Todd <scott.todd0@gmail.com>
parent a78659742a
commit c15f1a2bd2
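In rough terms, the pattern queries each dimension with torch.aten.size.int, folds the results together with torch.aten.mul.int, and packs the product via torch.aten.tensor.int. A minimal sketch of the IR this produces, assuming a rank-3 [3,4,5] input (value names illustrative, not exact):

    %int0 = torch.constant.int 0
    %int1 = torch.constant.int 1
    %int2 = torch.constant.int 2
    // Query each dimension of the operand as a !torch.int.
    %d0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int
    %d1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int
    %d2 = torch.aten.size.int %arg0, %int2 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int
    %false = torch.constant.bool false
    %none = torch.constant.none
    // Fold the dimensions into a single element count.
    %mul0 = torch.aten.mul.int %d0, %d1 : !torch.int, !torch.int -> !torch.int
    %mul1 = torch.aten.mul.int %mul0, %d2 : !torch.int, !torch.int -> !torch.int
    // Pack the scalar product into a rank-0 tensor.
    %size = torch.aten.tensor.int %mul1, %none, %none, %false : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si32>

The lit test added below checks exactly this shape of output.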
@@ -2032,6 +2032,50 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
            none, none, none);
        return success();
      });
  patterns.onOp(
      "Size", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value operand;
        if (binder.tensorOperand(operand) ||
            binder.tensorResultType(resultType))
          return failure();

        auto loc = binder.getLoc();
        auto &op = binder.op;
        auto operandTy = cast<Torch::BaseTensorType>(operand.getType());

        if (!operandTy.hasSizes())
          return rewriter.notifyMatchFailure(op, "input rank unknown");

        llvm::SmallVector<Value> dims;
        int64_t rank = operandTy.getSizes().size();
        for (int i = 0; i < rank; ++i) {
          auto iv = rewriter.create<Torch::ConstantIntOp>(
              loc, rewriter.getI64IntegerAttr(i));
          Value dim = rewriter.create<Torch::AtenSizeIntOp>(
              loc, rewriter.getType<Torch::IntType>(), operand, iv);
          dims.push_back(dim);
        }

        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
        Value none = rewriter.create<Torch::ConstantNoneOp>(loc);

        if (dims.empty()) {
          Value one = rewriter.create<Torch::ConstantIntOp>(
              loc, rewriter.getI64IntegerAttr(1));
          rewriter.replaceOpWithNewOp<Torch::AtenTensorIntOp>(
              op, resultType, one, none, none, cstFalse);
          return success();
        }

        Value prod = dims[0];
        for (int i = 1, s = dims.size(); i < s; ++i)
          prod = rewriter.create<Torch::AtenMulIntOp>(loc, prod, dims[i]);

        rewriter.replaceOpWithNewOp<Torch::AtenTensorIntOp>(
            op, resultType, prod, none, none, cstFalse);
        return success();
      });
  patterns.onOp(
      "Tile", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
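One subtlety in the hunk above: a rank-0 input yields an empty `dims`, so the pattern short-circuits to a constant 1 (the element count of a scalar tensor). A plausible sketch of that path's output, assuming a hypothetical `!torch.vtensor<[],f32>` operand and an si64 result type (ONNX Size returns int64 per the spec; the actual dtype follows the bound result type):

    %int1 = torch.constant.int 1
    %false = torch.constant.bool false
    %none = torch.constant.none
    // Rank-0 fallback: the size of a scalar tensor is the constant 1.
    %size = torch.aten.tensor.int %int1, %none, %none, %false : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si64>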
@@ -1649,3 +1649,24 @@ func.func @test_sign(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
  %0 = torch.operator "onnx.Sign"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
  return %0 : !torch.vtensor<[3,4,5],f32>
}

// -----

// CHECK-LABEL: func.func @test_size
func.func @test_size(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 9 : si64} {
  // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
  // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
  // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
  // CHECK-DAG: %[[D0:.+]] = torch.aten.size.int %arg0, %[[INT0]]
  // CHECK-DAG: %[[D1:.+]] = torch.aten.size.int %arg0, %[[INT1]]
  // CHECK-DAG: %[[D2:.+]] = torch.aten.size.int %arg0, %[[INT2]]
  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
  // CHECK-DAG: %[[MUL0:.+]] = torch.aten.mul.int %[[D0]], %[[D1]]
  // CHECK-DAG: %[[MUL1:.+]] = torch.aten.mul.int %[[MUL0]], %[[D2]]
  // CHECK-DAG: %[[TENSOR:.+]] = torch.aten.tensor.int %[[MUL1]], %[[NONE]], %[[NONE]], %[[FALSE]]
  // CHECK: return %[[TENSOR]]
  %0 = torch.operator "onnx.Size"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[],si32>
  return %0 : !torch.vtensor<[],si32>
}