mirror of https://github.com/llvm/torch-mlir
[Torch] support 1d aten tensor shape and dtype infer (#3776)
parent
ab62f35373
commit
b176939808
|
@ -46,6 +46,62 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class InferTensorOp : public OpRewritePattern<AtenTensorOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenTensorOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto context = op.getContext();
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
auto result = op.getResult();
|
||||||
|
auto resultType = cast<BaseTensorType>(result.getType());
|
||||||
|
if (resultType.hasSizes() && resultType.hasDtype()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "The result of aten.tensor is already a BaseTensorType.");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto inputList = op.getOperand(0);
|
||||||
|
auto listConstruct = inputList.getDefiningOp<PrimListConstructOp>();
|
||||||
|
if (!listConstruct) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "The operand 0 of aten.tensor is not PrimListConstructOp.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Currently only support the 1d input list.
|
||||||
|
SmallVector<int64_t> sizes;
|
||||||
|
sizes.push_back(listConstruct->getOperands().size());
|
||||||
|
FailureOr<Type> torchType;
|
||||||
|
auto eleType = listConstruct->getOperands()[0].getType();
|
||||||
|
if (isa<Torch::IntType>(eleType)) {
|
||||||
|
torchType = getTypeForScalarType(op->getContext(),
|
||||||
|
torch_upstream::ScalarType::Long);
|
||||||
|
} else if (isa<Torch::FloatType>(eleType)) {
|
||||||
|
torchType = getTypeForScalarType(op->getContext(),
|
||||||
|
torch_upstream::ScalarType::Float);
|
||||||
|
} else {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Currently only support Int and Float Type.");
|
||||||
|
}
|
||||||
|
auto newResultType = ValueTensorType::get(context, sizes, *torchType);
|
||||||
|
|
||||||
|
Value originalTypedValue;
|
||||||
|
for (OpOperand &use : llvm::make_early_inc_range(result.getUses())) {
|
||||||
|
if (!originalTypedValue) {
|
||||||
|
rewriter.setInsertionPointAfter(op);
|
||||||
|
originalTypedValue =
|
||||||
|
rewriter.create<TensorStaticInfoCastOp>(loc, resultType, result);
|
||||||
|
}
|
||||||
|
use.set(originalTypedValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
result.setType(newResultType);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
|
static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
|
||||||
int resultNum,
|
int resultNum,
|
||||||
PatternRewriter &rewriter) {
|
PatternRewriter &rewriter) {
|
||||||
|
@ -135,6 +191,7 @@ class SimplifyShapeCalculationsPass
|
||||||
populateFoldPrimUncheckedCastOpPattern(patterns, context);
|
populateFoldPrimUncheckedCastOpPattern(patterns, context);
|
||||||
patterns.insert<DecomposeAtenSizeOp>(context);
|
patterns.insert<DecomposeAtenSizeOp>(context);
|
||||||
patterns.insert<RefineShapeCalculateOp>(context);
|
patterns.insert<RefineShapeCalculateOp>(context);
|
||||||
|
patterns.insert<InferTensorOp>(context);
|
||||||
|
|
||||||
PrimIfOp::getCanonicalizationPatterns(patterns, context);
|
PrimIfOp::getCanonicalizationPatterns(patterns, context);
|
||||||
Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);
|
Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);
|
||||||
|
|
|
@ -5621,6 +5621,30 @@ def ConstantBoolParameterModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TensorAlloc1dStaticModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([2, 4, 6], torch.int, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
res = torch.tensor([x.shape[0]])
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TensorAlloc1dStaticModule())
|
||||||
|
def TensorAlloc1dStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 4, 6))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ScalarTensorFloat32Module(torch.nn.Module):
|
class ScalarTensorFloat32Module(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
Loading…
Reference in New Issue