mirror of https://github.com/llvm/torch-mlir
[Torch] support 1d aten tensor shape and dtype infer (#3776)
parent
ab62f35373
commit
b176939808
|
@ -46,6 +46,62 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class InferTensorOp : public OpRewritePattern<AtenTensorOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenTensorOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto context = op.getContext();
|
||||
auto loc = op.getLoc();
|
||||
auto result = op.getResult();
|
||||
auto resultType = cast<BaseTensorType>(result.getType());
|
||||
if (resultType.hasSizes() && resultType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "The result of aten.tensor is already a BaseTensorType.");
|
||||
}
|
||||
|
||||
auto inputList = op.getOperand(0);
|
||||
auto listConstruct = inputList.getDefiningOp<PrimListConstructOp>();
|
||||
if (!listConstruct) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "The operand 0 of aten.tensor is not PrimListConstructOp.");
|
||||
}
|
||||
|
||||
// Currently only support the 1d input list.
|
||||
SmallVector<int64_t> sizes;
|
||||
sizes.push_back(listConstruct->getOperands().size());
|
||||
FailureOr<Type> torchType;
|
||||
auto eleType = listConstruct->getOperands()[0].getType();
|
||||
if (isa<Torch::IntType>(eleType)) {
|
||||
torchType = getTypeForScalarType(op->getContext(),
|
||||
torch_upstream::ScalarType::Long);
|
||||
} else if (isa<Torch::FloatType>(eleType)) {
|
||||
torchType = getTypeForScalarType(op->getContext(),
|
||||
torch_upstream::ScalarType::Float);
|
||||
} else {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Currently only support Int and Float Type.");
|
||||
}
|
||||
auto newResultType = ValueTensorType::get(context, sizes, *torchType);
|
||||
|
||||
Value originalTypedValue;
|
||||
for (OpOperand &use : llvm::make_early_inc_range(result.getUses())) {
|
||||
if (!originalTypedValue) {
|
||||
rewriter.setInsertionPointAfter(op);
|
||||
originalTypedValue =
|
||||
rewriter.create<TensorStaticInfoCastOp>(loc, resultType, result);
|
||||
}
|
||||
use.set(originalTypedValue);
|
||||
}
|
||||
|
||||
result.setType(newResultType);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
|
||||
int resultNum,
|
||||
PatternRewriter &rewriter) {
|
||||
|
@ -135,6 +191,7 @@ class SimplifyShapeCalculationsPass
|
|||
populateFoldPrimUncheckedCastOpPattern(patterns, context);
|
||||
patterns.insert<DecomposeAtenSizeOp>(context);
|
||||
patterns.insert<RefineShapeCalculateOp>(context);
|
||||
patterns.insert<InferTensorOp>(context);
|
||||
|
||||
PrimIfOp::getCanonicalizationPatterns(patterns, context);
|
||||
Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);
|
||||
|
|
|
@ -5621,6 +5621,30 @@ def ConstantBoolParameterModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class TensorAlloc1dStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 4, 6], torch.int, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
res = torch.tensor([x.shape[0]])
|
||||
return res
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TensorAlloc1dStaticModule())
|
||||
def TensorAlloc1dStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 6))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ScalarTensorFloat32Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue