mirror of https://github.com/llvm/torch-mlir
[torch] Support lowering `torch.item` to `tensor.extract` (#2835)
Extracting scalar values from tensors can be implemented via a lowering to tensor.extract.pull/2847/head
parent
8a17c98b74
commit
0114a570e3
|
@ -28,6 +28,47 @@ using namespace mlir::torch::Torch;
|
|||
|
||||
namespace {
|
||||
|
||||
class ConvertAtenItemOp : public OpConversionPattern<AtenItemOp> {
|
||||
public:
|
||||
using OpConversionPattern<AtenItemOp>::OpConversionPattern;
|
||||
using OpAdaptor = typename AtenItemOp::Adaptor;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenItemOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto operand = adaptor.getOperands()[0];
|
||||
auto operandTy = cast<RankedTensorType>(operand.getType());
|
||||
auto torchDTy = cast<ValueTensorType>(op.getOperand().getType()).getDtype();
|
||||
|
||||
if (operandTy.getNumElements() != 1)
|
||||
return rewriter.notifyMatchFailure(op, "expected only one item");
|
||||
|
||||
auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
|
||||
auto rank = operandTy.getRank();
|
||||
llvm::SmallVector<Value> indices(rank, zeroIdx);
|
||||
|
||||
Value extract = rewriter.create<tensor::ExtractOp>(
|
||||
op.getLoc(), operandTy.getElementType(), operand, indices);
|
||||
auto extractTy = extract.getType();
|
||||
if (isa<mlir::IntegerType>(extractTy) && !extractTy.isInteger(64)) {
|
||||
if (torchDTy.isSignlessInteger()) {
|
||||
extract = rewriter.create<arith::ExtUIOp>(
|
||||
op.getLoc(), rewriter.getIntegerType(64), extract);
|
||||
} else {
|
||||
extract = rewriter.create<arith::ExtSIOp>(
|
||||
op.getLoc(), rewriter.getIntegerType(64), extract);
|
||||
}
|
||||
}
|
||||
|
||||
if (isa<mlir::FloatType>(extractTy) && !extractTy.isF64()) {
|
||||
extract = rewriter.create<arith::ExtFOp>(op.getLoc(),
|
||||
rewriter.getF64Type(), extract);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, extract);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class ConvertAtenShapeToTensorPatternOp
|
||||
: public OpConversionPattern<Aten_ShapeAsTensorOp> {
|
||||
public:
|
||||
|
@ -70,6 +111,7 @@ public:
|
|||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<arith::ArithDialect>();
|
||||
target.addLegalDialect<tensor::TensorDialect>();
|
||||
target.addIllegalOp<Torch::AtenItemOp>();
|
||||
target.addIllegalOp<Torch::Aten_ShapeAsTensorOp>();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
|
@ -77,7 +119,8 @@ public:
|
|||
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.add<ConvertAtenShapeToTensorPatternOp>(typeConverter, context);
|
||||
patterns.add<ConvertAtenShapeToTensorPatternOp, ConvertAtenItemOp>(
|
||||
typeConverter, context);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
|
||||
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
||||
#include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h"
|
||||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
|
@ -76,6 +77,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
|
|||
pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToTensorPass());
|
||||
pm.addPass(createConvertTorchConversionToMLProgramPass());
|
||||
pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
|
||||
|
||||
|
|
|
@ -130,6 +130,10 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
'ViewCollapseDynamicWithAtenSizeIntModule_basic',
|
||||
# END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
|
||||
|
||||
# ERROR: torch._dynamo.exc.Unsupported: Tensor.item
|
||||
'AtenItemIntOpModule_basic',
|
||||
'AtenItemFpOpModule_basic',
|
||||
|
||||
# ERROR: torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {'reverse': ConstantVariable(bool)}
|
||||
'SortIntListReverse_basic',
|
||||
|
||||
|
|
|
@ -428,3 +428,41 @@ class AtenIntTensorCharDtypeModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: AtenIntTensorCharDtypeModule())
|
||||
def AtenIntTensorCharDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class AtenItemIntOpModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([], torch.int8, True),
|
||||
])
|
||||
|
||||
def forward(self, val):
|
||||
return int(val)
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenItemIntOpModule())
|
||||
def AtenItemIntOpModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class AtenItemFpOpModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([], torch.float, True),
|
||||
])
|
||||
|
||||
def forward(self, val):
|
||||
return float(val)
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenItemFpOpModule())
|
||||
def AtenItemFpOpModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1))
|
||||
|
|
Loading…
Reference in New Issue