Add Scalarization Patterns for `AtenToDtypeOp`, `AtenNegOp`, `AtenRemainderTensorOp` (#3861)

1. adds a lowering for `aten.neg.int` and `aten.remainder.int` to arith.
2. adds a scalarization pattern for `aten.neg` and
`aten.remainder.Tensor` ops.
3. improves folding of `aten.mul.int`
4. adds a scalarization pattern for `aten.to.dtype` which relies on
scalar cast ops and basic C++ casting between `double` and `int64_t`.
5. improves rank-0 case handling for `FoldAtenSplatPattern`
6. removes a bug with `aten.unflatten.int` decomposition incorrectly
generating a constant size int from a dynamic shape.
7. simplifies the dim list for `aten.unflatten.int` ops generated from
the `aten.view` canonicalization in scalarize shapes.

All of these changes were necessary to unblock
<https://github.com/iree-org/iree/issues/18899>.
main
zjgarvey 2024-11-12 14:25:02 -06:00 committed by GitHub
parent 889a836b3d
commit cd38ecf6c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 302 additions and 30 deletions

View File

@ -82,6 +82,25 @@ public:
};
} // namespace
namespace {
class ConvertAtenNegIntOp : public OpConversionPattern<AtenNegIntOp> {
public:
using OpConversionPattern<AtenNegIntOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenNegIntOp op,
typename OpConversionPattern<AtenNegIntOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value a = adaptor.getA();
rewriter.replaceOpWithNewOp<arith::SubIOp>(
op,
rewriter.create<arith::ConstantIntOp>(op.getLoc(), /*value=*/0,
/*bitwidth=*/64),
a);
return success();
}
};
} // namespace
namespace {
template <typename AtenOp, typename UnaryOp>
class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern<AtenOp> {
@ -465,11 +484,14 @@ public:
target.addIllegalOp<AtenAddOp>();
patterns.add<ConvertAtenAddOp>(typeConverter, context);
target.addIllegalOp<AtenNegIntOp>();
patterns.add<ConvertAtenNegIntOp>(typeConverter, context);
target.addIllegalOp<AtenAddIntOp, AtenAddFloatIntOp, AtenSubIntOp,
AtenMulIntOp>();
AtenMulIntOp, AtenRemainderIntOp>();
patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenRemainderIntOp, arith::RemSIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenAddFloatIntOp, arith::AddFOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(

View File

@ -4068,6 +4068,10 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
int64_t lhs, rhs;
bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
if (lConstant && lhs == 1)
return getOperand(1);
if (rConstant && rhs == 1)
return getOperand(0);
if ((lConstant && lhs == 0) || (rConstant && rhs == 0))
return getI64IntegerAttr(getContext(), 0);
if (lConstant && rConstant)

View File

@ -4587,6 +4587,11 @@ public:
if (!isValidDim(dimInt, inputRank))
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
if (inputShape[dimInt] == Torch::kUnknownSize &&
llvm::count(sizesInts, -1) > 0)
return rewriter.notifyMatchFailure(
op, "Unimplemented: dynamic unflatten dim with an inferred size.");
SmallVector<Value> sizesTorchInt;
if (!getListConstructElements(op.getSizes(), sizesTorchInt))
return rewriter.notifyMatchFailure(

View File

@ -714,7 +714,7 @@ public:
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
// Rank 0 item op prop
if (selfTy.getSizes().size() == 0) {
if (selfTy.getSizes().empty()) {
auto numToTensor = self.getDefiningOp<Torch::PrimNumToTensorScalarOp>();
auto squeezeDim = self.getDefiningOp<AtenSqueezeDimOp>();
if (!squeezeDim && !numToTensor)
@ -746,6 +746,109 @@ public:
};
} // namespace
namespace {
LogicalResult convertOpFoldResults(ImplicitLocOpBuilder &b,
SmallVector<OpFoldResult> &converted,
SmallVector<OpFoldResult> &elements,
Type inputDtype, Type resultDtype) {
auto inputIsInt = dyn_cast<mlir::IntegerType>(inputDtype);
auto resultIsInt = dyn_cast<mlir::IntegerType>(resultDtype);
if (!inputIsInt && !isa<mlir::FloatType>(inputDtype))
return failure();
if (!resultIsInt && !isa<mlir::FloatType>(resultDtype))
return failure();
// if dtypes are both int or both float, no conversion needed
if (static_cast<bool>(inputIsInt) == static_cast<bool>(resultIsInt)) {
converted = elements;
return success();
}
if (resultIsInt) {
for (auto &e : elements) {
auto eValue = dyn_cast<Value>(e);
if (eValue) {
converted.push_back(b.createOrFold<AtenIntScalarOp>(eValue));
continue;
}
auto eAttr = dyn_cast<Attribute>(e);
auto eFloatAttr = dyn_cast_or_null<FloatAttr>(eAttr);
if (!eFloatAttr)
return failure();
converted.push_back(IntegerAttr::get(
resultDtype, static_cast<int64_t>(eFloatAttr.getValueAsDouble())));
}
return success();
}
// result is float
for (auto &e : elements) {
auto eValue = dyn_cast<Value>(e);
if (eValue) {
converted.push_back(b.createOrFold<AtenFloatScalarOp>(eValue));
continue;
}
auto eAttr = dyn_cast<Attribute>(e);
auto eIntAttr = dyn_cast<IntegerAttr>(eAttr);
if (!eIntAttr)
return failure();
auto eInt = (inputIsInt.isSigned()) ? eIntAttr.getValue().getSExtValue()
: eIntAttr.getValue().getZExtValue();
converted.push_back(FloatAttr::get(resultDtype, static_cast<double>(eInt)));
}
return success();
}
class PropagateAtenToDtypePattern : public OpRewritePattern<AtenToDtypeOp> {
public:
using OpRewritePattern<AtenToDtypeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenToDtypeOp op,
PatternRewriter &rewriter) const override {
bool nonBlocking, copyArg;
// The non_blocking arg must be `False`.
if (!matchPattern(op.getNonBlocking(), m_TorchConstantBool(&nonBlocking)) ||
nonBlocking)
return failure();
// The copy arg must be `False`.
if (!matchPattern(op.getCopy(), m_TorchConstantBool(&copyArg)) || copyArg)
return failure();
// The memory_format arg must be `none`.
if (!isa<Torch::NoneType>(op.getMemoryFormat().getType()))
return failure();
auto inputType = dyn_cast<ValueTensorType>(op.getSelf().getType());
auto resultType = dyn_cast<ValueTensorType>(op.getType());
if (!inputType || !resultType || !inputType.hasDtype() ||
!resultType.hasDtype())
return failure();
auto inputDtype = inputType.getDtype();
auto resultDtype = resultType.getDtype();
SmallVector<OpFoldResult> elements;
if (failed(getListFromTensor(op.getSelf(), elements)))
return failure();
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
SmallVector<OpFoldResult> converted;
if (failed(convertOpFoldResults(b, converted, elements, inputDtype,
resultDtype)))
return rewriter.notifyMatchFailure(
op, "Unhandled attribute type encountered.");
SmallVector<Value> vals;
if (failed(materializeFolds(b, converted, vals)))
return failure();
Value result = constructAtenTensorOpFromList(b, op.getType(), vals);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
namespace {
template <typename AtenViewLikeOp>
class PropagateAtenViewLikePattern : public OpRewritePattern<AtenViewLikeOp> {
@ -828,7 +931,7 @@ public:
if (failed(materializeFolds(b, resultFolds, resultVals)))
return failure();
if (resultTy.getSizes().size() == 0) {
if (resultTy.getSizes().empty()) {
rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
op, resultTy, resultVals.front());
return success();
@ -841,6 +944,48 @@ public:
};
} // namespace
namespace {
template <typename OpTy, typename ScalarOpTy>
class PropagateAtenUnaryPattern : public OpRewritePattern<OpTy> {
public:
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
// Check type
auto resultTy = cast<ValueTensorType>(op.getType());
if (resultTy.getSizes().size() > 1)
return rewriter.notifyMatchFailure(op, "unsupported: rank > 1");
if (!resultTy.hasDtype() || !isa<mlir::IntegerType>(resultTy.getDtype()))
return rewriter.notifyMatchFailure(op, "not an int type");
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
SmallVector<OpFoldResult> selfFold;
if (failed(getListFromTensor(op.getSelf(), selfFold)))
return failure();
SmallVector<Value> selfVals;
if (failed(materializeFolds(b, selfFold, selfVals)))
return failure();
SmallVector<OpFoldResult> resultFolds;
for (uint64_t i = 0; i < selfVals.size(); i++) {
resultFolds.push_back(
b.createOrFold<ScalarOpTy>(selfVals[i].getType(), selfVals[i]));
}
SmallVector<Value> resultVals;
if (failed(materializeFolds(b, resultFolds, resultVals)))
return failure();
if (resultTy.getSizes().size() == 0) {
rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
op, resultTy, resultVals.front());
return success();
}
Value result = constructAtenTensorOpFromList(b, resultTy, resultVals);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
/// ------ Fold Patterns ------ ///
// These are shape-specific folding patterns
@ -915,6 +1060,11 @@ public:
auto resultTy = cast<BaseTensorType>(op.getType());
if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
return rewriter.notifyMatchFailure(op, "dynamic output shape");
if (resultTy.getSizes().size() == 0) {
rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
op, op.getType(), elements.front());
return success();
}
auto loc = op.getLoc();
SmallVector<Value> sizes;
@ -922,12 +1072,10 @@ public:
sizes.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(size)));
Value one = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getType<Torch::IntType>(), 1);
Value sizeList = rewriter.create<Torch::PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
one);
sizes);
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
@ -1031,6 +1179,24 @@ public:
};
} // namespace
namespace {
// fold ridiculous patterns like size.int -> float.scalar -> int.scalar
class FoldAtenIntScalarPattern : public OpRewritePattern<AtenIntScalarOp> {
public:
using OpRewritePattern<AtenIntScalarOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenIntScalarOp op,
PatternRewriter &rewriter) const override {
auto floatScalarOp = op.getA().getDefiningOp<AtenFloatScalarOp>();
if (!floatScalarOp)
return failure();
auto sizeOp = floatScalarOp.getA().getDefiningOp<AtenSizeIntOp>();
if (!sizeOp)
return failure();
rewriter.replaceOp(op, floatScalarOp.getA());
return success();
}
};
} // namespace
namespace {
class FoldAtenUnsqueezePattern : public OpRewritePattern<AtenUnsqueezeOp> {
public:
@ -1182,8 +1348,29 @@ public:
if (inputUnmatched == 1 && outputUnmatched > 1) {
Value dimVal =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), leftMatchEnd);
ArrayRef<Value> unflattenSizes(viewSizes.begin() + leftMatchEnd,
viewSizes.end() - rightMatchEnd);
SmallVector<Value> unflattenSizes(viewSizes.begin() + leftMatchEnd,
viewSizes.end() - rightMatchEnd);
// try to convert a single dynamic size input to -1
int64_t dynCount = 0;
int64_t dynIdx = 0;
for (auto [i, v] : llvm::enumerate(unflattenSizes)) {
int64_t szeInt;
if (!matchPattern(v, m_TorchConstantInt(&szeInt))) {
dynCount++;
dynIdx = i;
continue;
}
// if we have a -1 already, make dynCount invalid and break
if (szeInt == -1) {
dynCount = -1;
break;
}
}
// if only one size is dynamic, make it -1
if (dynCount == 1)
unflattenSizes[dynIdx] =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), -1);
Value unflattenList = rewriter.create<Torch::PrimListConstructOp>(
op.getLoc(), op.getSize().getType(), unflattenSizes);
rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
@ -1227,6 +1414,18 @@ public:
namespace {
bool isItemForSliceOp(Operation *op) {
auto itemOp = dyn_cast_or_null<AtenItemOp>(op);
if (!itemOp)
return false;
for (OpOperand &use : op->getUses()) {
Operation *userOp = use.getOwner();
if (isa<AtenSliceTensorOp>(userOp))
return true;
}
return false;
}
bool isSourceOpForShapeScalarization(Operation *op) {
return llvm::isa<AtenSizeIntOp, Torch::ConstantIntOp, Torch::ConstantBoolOp,
Aten_ShapeAsTensorOp, Torch::ValueTensorLiteralOp>(op);
@ -1244,7 +1443,7 @@ bool isPrimListOfInts(Operation *op) {
bool isAnchorOp(Operation *op) {
return isa<Torch::RuntimeAssertOp>(op) || isa<AtenArangeStartStepOp>(op) ||
isPrimListOfInts(op);
isPrimListOfInts(op) || isItemForSliceOp(op);
}
// The argument to this function, op, is the use of some source op, srcOp. If
@ -1278,9 +1477,9 @@ bool isInvalidValidViewConsumer(Operation *op,
void populateScalarizationFoldPatterns(RewritePatternSet &patterns) {
patterns.insert<FoldAtenSqueezePattern<AtenSqueezeOp>,
FoldAtenSqueezePattern<AtenSqueezeDimOp>,
FoldAtenUnsqueezePattern, FoldAtenWhereSelf,
FoldAtenTensorSplatPattern, FoldAtenEqIntPattern>(
patterns.getContext());
FoldAtenIntScalarPattern, FoldAtenUnsqueezePattern,
FoldAtenWhereSelf, FoldAtenTensorSplatPattern,
FoldAtenEqIntPattern>(patterns.getContext());
}
void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) {
@ -1303,10 +1502,12 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern,
PropagateAtenTransposeIntPattern,
PropagateAtenTransposeIntPattern, PropagateAtenToDtypePattern,
PropagateAtenUnaryPattern<AtenNegOp, AtenNegIntOp>,
PropagateAtenArithmeticPattern<AtenAddTensorOp, AtenAddIntOp>,
PropagateAtenArithmeticPattern<AtenSubTensorOp, AtenSubIntOp>,
PropagateAtenArithmeticPattern<AtenMulTensorOp, AtenMulIntOp>,
PropagateAtenArithmeticPattern<AtenRemainderTensorOp, AtenRemainderIntOp>,
PropagateAtenArithmeticPattern<AtenDivTensorOp, AtenFloordivIntOp>>(
patterns.getContext());
}
@ -1314,6 +1515,7 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
patterns.insert<RemoveUnusedPattern<Torch::AtenIntBoolOp>,
RemoveUnusedPattern<Torch::AtenEqIntOp>,
RemoveUnusedPattern<Torch::AtenToDtypeOp>,
RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
RemoveUnusedPattern<Torch::AtenFullOp>,
RemoveUnusedPattern<Torch::AtenUnsqueezeOp>,
@ -1321,6 +1523,8 @@ void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
RemoveUnusedPattern<Torch::AtenSizeIntOp>,
RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
RemoveUnusedPattern<Torch::AtenTensorOp>,
RemoveUnusedPattern<Torch::AtenFloatScalarOp>,
RemoveUnusedPattern<Torch::AtenIntScalarOp>,
RemoveUnusedPattern<Torch::PrimListConstructOp>>(
patterns.getContext());
}

View File

@ -105,9 +105,9 @@ func.func @torch.aten.fake_quantize_per_channel_affine_cachemask(%arg0: !torch.v
// CHECK-LABEL: test_einsum_inner_prod
func.func @test_einsum_inner_prod(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[5],f64>) -> !torch.vtensor<[],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64} {
// CHECK: %[[INT5:.+]] = torch.constant.int 5
// CHECK: %[[INT1:.+]] = torch.constant.int 1
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK-DAG: %[[INT5:.+]] = torch.constant.int 5
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[LHS_LIST:.+]] = torch.prim.ListConstruct %[[INT0]]
// CHECK: %[[LHS_PERM:.+]] = torch.aten.permute %arg0, %[[LHS_LIST]]
// CHECK: %[[RHS_LIST:.+]] = torch.prim.ListConstruct %[[INT0]]

View File

@ -27,12 +27,8 @@ func.func @shape_as_tensor(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtenso
// CHECK-LABEL: @shape_as_tensor_dim
func.func @shape_as_tensor_dim(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],si32> {
// CHECK: %[[INT1:.+]] = torch.constant.int 1
// CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[INT1]]
// CHECK: %[[INT1_0:.+]] = torch.constant.int 1
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT1_0]]
// CHECK: %[[TENSOR:.+]] = torch.aten.full %[[LIST]], %[[SZ]], %[[NONE]], %[[NONE]], %[[NONE]], %[[FALSE]]
// CHECK-DAG: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[INT1]]
// CHECK: %[[TENSOR:.+]] = torch.prim.NumToTensor.Scalar %[[SZ]] : !torch.int -> !torch.vtensor<[],si32>
// CHECK: return %[[TENSOR]] : !torch.vtensor<[],si32>
%shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32>
%dim = torch.constant.int 0
@ -43,6 +39,49 @@ func.func @shape_as_tensor_dim(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vt
return %select : !torch.vtensor<[],si32>
}
// -----
// CHECK-LABEL: @cast_int_int
func.func @cast_int_int(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],si64> {
// CHECK: %[[I1:.*]] = torch.constant.int 1
// CHECK: %[[SZE:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[5,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SZE]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[TENSOR]] : !torch.vtensor<[],si64>
%int4 = torch.constant.int 4
%false = torch.constant.bool false
%none = torch.constant.none
%shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32>
%cast_shape = torch.aten.to.dtype %shape, %int4, %false, %false, %none : !torch.vtensor<[3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3],si64>
%dim = torch.constant.int 0
%idx = torch.vtensor.literal(dense<1> : tensor<si32>) : !torch.vtensor<[],si32>
%select = torch.aten.index_select %cast_shape, %dim, %idx : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si64>
%item = torch.aten.item %select : !torch.vtensor<[],si64> -> !torch.int
%list = torch.prim.ListConstruct %item : (!torch.int) -> !torch.list<int>
return %select : !torch.vtensor<[],si64>
}
// -----
// CHECK-LABEL: @cast_int_float
func.func @cast_int_float(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[I1:.*]] = torch.constant.int 1
// CHECK: %[[SZE:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[5,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[FLOAT:.*]] = torch.aten.Float.Scalar %[[SZE]] : !torch.int -> !torch.float
// CHECK: %[[TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[FLOAT]] : !torch.float -> !torch.vtensor<[],f32>
// CHECK: return %[[TENSOR]] : !torch.vtensor<[],f32>
%int6 = torch.constant.int 6
%false = torch.constant.bool false
%none = torch.constant.none
%shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32>
%cast_shape = torch.aten.to.dtype %shape, %int6, %false, %false, %none : !torch.vtensor<[3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3],f32>
%dim = torch.constant.int 0
%idx = torch.vtensor.literal(dense<1> : tensor<si32>) : !torch.vtensor<[],si32>
%select = torch.aten.index_select %cast_shape, %dim, %idx : !torch.vtensor<[3],f32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],f32>
%item = torch.aten.item %select : !torch.vtensor<[],f32> -> !torch.float
%item_int = torch.aten.Int.Scalar %item : !torch.float -> !torch.int
%list = torch.prim.ListConstruct %item_int : (!torch.int) -> !torch.list<int>
return %select : !torch.vtensor<[],f32>
}
// -----
@ -89,14 +128,12 @@ func.func @arith_prop(%arg0 : !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?]
// CHECK: %[[x2:.*]] = torch.aten.floordiv.int %[[x0]], %[[int12]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[x3:.*]] = torch.aten.floordiv.int %[[x1]], %[[int1_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[int12_1:.*]] = torch.constant.int 12
// CHECK: %[[int1_2:.*]] = torch.constant.int 1
// CHECK: %[[x4:.*]] = torch.aten.mul.int %[[x2]], %[[int12_1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[x5:.*]] = torch.aten.mul.int %[[x3]], %[[int1_2]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[x6:.*]] = torch.aten.sub.int %[[x0]], %[[x4]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[x7:.*]] = torch.aten.sub.int %[[x1]], %[[x5]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[x8:.*]] = torch.prim.ListConstruct %[[x7]], %[[x6]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[x9:.*]] = torch.aten.constant_pad_nd %arg0, %[[x8]], %[[float0]] : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[x9]] : !torch.vtensor<[?,?],f32>
// CHECK: %[[x5:.*]] = torch.aten.sub.int %[[x0]], %[[x4]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[x6:.*]] = torch.aten.sub.int %[[x1]], %[[x3]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[x7:.*]] = torch.prim.ListConstruct %[[x6]], %[[x5]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[x8:.*]] = torch.aten.constant_pad_nd %arg0, %[[x7]], %[[float0]] : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[x8]] : !torch.vtensor<[?,?],f32>
%0 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
%1 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%float0.000000e00 = torch.constant.float 0.000000e+00