mirror of https://github.com/llvm/torch-mlir
Add Scalarization Patterns for `AtenToDtypeOp`, `AtenNegOp`, `AtenRemainderTensorOp` (#3861)
1. adds a lowering for `aten.neg.int` and `aten.remainder.int` to arith. 2. adds a scalarization pattern for `aten.neg` and `aten.remainder.Tensor` ops. 3. improves folding of `aten.mul.int` 4. adds a scalarization pattern for `aten.to.dtype` which relies on scalar cast ops and basic C++ casting between `double` and `int64_t`. 5. improves rank-0 case handling for `FoldAtenSplatPattern` 6. removes a bug with `aten.unflatten.int` decomposition incorrectly generating a constant size int from a dynamic shape. 7. simplifies the dim list for `aten.unflatten.int` ops generated from the `aten.view` canonicalization in scalarize shapes. All of these changes were necessary to unblock <https://github.com/iree-org/iree/issues/18899>.main
parent
889a836b3d
commit
cd38ecf6c2
|
@ -82,6 +82,25 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenNegIntOp : public OpConversionPattern<AtenNegIntOp> {
|
||||
public:
|
||||
using OpConversionPattern<AtenNegIntOp>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenNegIntOp op,
|
||||
typename OpConversionPattern<AtenNegIntOp>::OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value a = adaptor.getA();
|
||||
rewriter.replaceOpWithNewOp<arith::SubIOp>(
|
||||
op,
|
||||
rewriter.create<arith::ConstantIntOp>(op.getLoc(), /*value=*/0,
|
||||
/*bitwidth=*/64),
|
||||
a);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
template <typename AtenOp, typename UnaryOp>
|
||||
class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern<AtenOp> {
|
||||
|
@ -465,11 +484,14 @@ public:
|
|||
|
||||
target.addIllegalOp<AtenAddOp>();
|
||||
patterns.add<ConvertAtenAddOp>(typeConverter, context);
|
||||
|
||||
target.addIllegalOp<AtenNegIntOp>();
|
||||
patterns.add<ConvertAtenNegIntOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenAddIntOp, AtenAddFloatIntOp, AtenSubIntOp,
|
||||
AtenMulIntOp>();
|
||||
AtenMulIntOp, AtenRemainderIntOp>();
|
||||
patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
|
||||
typeConverter, context);
|
||||
patterns.add<ConvertAtenBinaryOp<AtenRemainderIntOp, arith::RemSIOp>>(
|
||||
typeConverter, context);
|
||||
patterns.add<ConvertAtenBinaryOp<AtenAddFloatIntOp, arith::AddFOp>>(
|
||||
typeConverter, context);
|
||||
patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(
|
||||
|
|
|
@ -4068,6 +4068,10 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
|
|||
int64_t lhs, rhs;
|
||||
bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
|
||||
bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
|
||||
if (lConstant && lhs == 1)
|
||||
return getOperand(1);
|
||||
if (rConstant && rhs == 1)
|
||||
return getOperand(0);
|
||||
if ((lConstant && lhs == 0) || (rConstant && rhs == 0))
|
||||
return getI64IntegerAttr(getContext(), 0);
|
||||
if (lConstant && rConstant)
|
||||
|
|
|
@ -4587,6 +4587,11 @@ public:
|
|||
if (!isValidDim(dimInt, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
|
||||
|
||||
if (inputShape[dimInt] == Torch::kUnknownSize &&
|
||||
llvm::count(sizesInts, -1) > 0)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented: dynamic unflatten dim with an inferred size.");
|
||||
|
||||
SmallVector<Value> sizesTorchInt;
|
||||
if (!getListConstructElements(op.getSizes(), sizesTorchInt))
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
|
|
@ -714,7 +714,7 @@ public:
|
|||
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
||||
|
||||
// Rank 0 item op prop
|
||||
if (selfTy.getSizes().size() == 0) {
|
||||
if (selfTy.getSizes().empty()) {
|
||||
auto numToTensor = self.getDefiningOp<Torch::PrimNumToTensorScalarOp>();
|
||||
auto squeezeDim = self.getDefiningOp<AtenSqueezeDimOp>();
|
||||
if (!squeezeDim && !numToTensor)
|
||||
|
@ -746,6 +746,109 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
|
||||
LogicalResult convertOpFoldResults(ImplicitLocOpBuilder &b,
|
||||
SmallVector<OpFoldResult> &converted,
|
||||
SmallVector<OpFoldResult> &elements,
|
||||
Type inputDtype, Type resultDtype) {
|
||||
auto inputIsInt = dyn_cast<mlir::IntegerType>(inputDtype);
|
||||
auto resultIsInt = dyn_cast<mlir::IntegerType>(resultDtype);
|
||||
if (!inputIsInt && !isa<mlir::FloatType>(inputDtype))
|
||||
return failure();
|
||||
if (!resultIsInt && !isa<mlir::FloatType>(resultDtype))
|
||||
return failure();
|
||||
|
||||
// if dtypes are both int or both float, no conversion needed
|
||||
if (static_cast<bool>(inputIsInt) == static_cast<bool>(resultIsInt)) {
|
||||
converted = elements;
|
||||
return success();
|
||||
}
|
||||
|
||||
if (resultIsInt) {
|
||||
for (auto &e : elements) {
|
||||
auto eValue = dyn_cast<Value>(e);
|
||||
if (eValue) {
|
||||
converted.push_back(b.createOrFold<AtenIntScalarOp>(eValue));
|
||||
continue;
|
||||
}
|
||||
auto eAttr = dyn_cast<Attribute>(e);
|
||||
auto eFloatAttr = dyn_cast_or_null<FloatAttr>(eAttr);
|
||||
if (!eFloatAttr)
|
||||
return failure();
|
||||
|
||||
converted.push_back(IntegerAttr::get(
|
||||
resultDtype, static_cast<int64_t>(eFloatAttr.getValueAsDouble())));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
// result is float
|
||||
for (auto &e : elements) {
|
||||
auto eValue = dyn_cast<Value>(e);
|
||||
if (eValue) {
|
||||
converted.push_back(b.createOrFold<AtenFloatScalarOp>(eValue));
|
||||
continue;
|
||||
}
|
||||
auto eAttr = dyn_cast<Attribute>(e);
|
||||
auto eIntAttr = dyn_cast<IntegerAttr>(eAttr);
|
||||
if (!eIntAttr)
|
||||
return failure();
|
||||
|
||||
auto eInt = (inputIsInt.isSigned()) ? eIntAttr.getValue().getSExtValue()
|
||||
: eIntAttr.getValue().getZExtValue();
|
||||
converted.push_back(FloatAttr::get(resultDtype, static_cast<double>(eInt)));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
class PropagateAtenToDtypePattern : public OpRewritePattern<AtenToDtypeOp> {
|
||||
public:
|
||||
using OpRewritePattern<AtenToDtypeOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenToDtypeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
bool nonBlocking, copyArg;
|
||||
// The non_blocking arg must be `False`.
|
||||
if (!matchPattern(op.getNonBlocking(), m_TorchConstantBool(&nonBlocking)) ||
|
||||
nonBlocking)
|
||||
return failure();
|
||||
// The copy arg must be `False`.
|
||||
if (!matchPattern(op.getCopy(), m_TorchConstantBool(©Arg)) || copyArg)
|
||||
return failure();
|
||||
// The memory_format arg must be `none`.
|
||||
if (!isa<Torch::NoneType>(op.getMemoryFormat().getType()))
|
||||
return failure();
|
||||
|
||||
auto inputType = dyn_cast<ValueTensorType>(op.getSelf().getType());
|
||||
auto resultType = dyn_cast<ValueTensorType>(op.getType());
|
||||
if (!inputType || !resultType || !inputType.hasDtype() ||
|
||||
!resultType.hasDtype())
|
||||
return failure();
|
||||
auto inputDtype = inputType.getDtype();
|
||||
auto resultDtype = resultType.getDtype();
|
||||
|
||||
SmallVector<OpFoldResult> elements;
|
||||
if (failed(getListFromTensor(op.getSelf(), elements)))
|
||||
return failure();
|
||||
|
||||
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
||||
SmallVector<OpFoldResult> converted;
|
||||
if (failed(convertOpFoldResults(b, converted, elements, inputDtype,
|
||||
resultDtype)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unhandled attribute type encountered.");
|
||||
|
||||
SmallVector<Value> vals;
|
||||
if (failed(materializeFolds(b, converted, vals)))
|
||||
return failure();
|
||||
|
||||
Value result = constructAtenTensorOpFromList(b, op.getType(), vals);
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
template <typename AtenViewLikeOp>
|
||||
class PropagateAtenViewLikePattern : public OpRewritePattern<AtenViewLikeOp> {
|
||||
|
@ -828,7 +931,7 @@ public:
|
|||
if (failed(materializeFolds(b, resultFolds, resultVals)))
|
||||
return failure();
|
||||
|
||||
if (resultTy.getSizes().size() == 0) {
|
||||
if (resultTy.getSizes().empty()) {
|
||||
rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
|
||||
op, resultTy, resultVals.front());
|
||||
return success();
|
||||
|
@ -841,6 +944,48 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
template <typename OpTy, typename ScalarOpTy>
|
||||
class PropagateAtenUnaryPattern : public OpRewritePattern<OpTy> {
|
||||
public:
|
||||
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(OpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Check type
|
||||
auto resultTy = cast<ValueTensorType>(op.getType());
|
||||
if (resultTy.getSizes().size() > 1)
|
||||
return rewriter.notifyMatchFailure(op, "unsupported: rank > 1");
|
||||
if (!resultTy.hasDtype() || !isa<mlir::IntegerType>(resultTy.getDtype()))
|
||||
return rewriter.notifyMatchFailure(op, "not an int type");
|
||||
|
||||
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
||||
SmallVector<OpFoldResult> selfFold;
|
||||
if (failed(getListFromTensor(op.getSelf(), selfFold)))
|
||||
return failure();
|
||||
SmallVector<Value> selfVals;
|
||||
if (failed(materializeFolds(b, selfFold, selfVals)))
|
||||
return failure();
|
||||
SmallVector<OpFoldResult> resultFolds;
|
||||
for (uint64_t i = 0; i < selfVals.size(); i++) {
|
||||
resultFolds.push_back(
|
||||
b.createOrFold<ScalarOpTy>(selfVals[i].getType(), selfVals[i]));
|
||||
}
|
||||
SmallVector<Value> resultVals;
|
||||
if (failed(materializeFolds(b, resultFolds, resultVals)))
|
||||
return failure();
|
||||
|
||||
if (resultTy.getSizes().size() == 0) {
|
||||
rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
|
||||
op, resultTy, resultVals.front());
|
||||
return success();
|
||||
}
|
||||
|
||||
Value result = constructAtenTensorOpFromList(b, resultTy, resultVals);
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
/// ------ Fold Patterns ------ ///
|
||||
// These are shape-specific folding patterns
|
||||
|
||||
|
@ -915,6 +1060,11 @@ public:
|
|||
auto resultTy = cast<BaseTensorType>(op.getType());
|
||||
if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
|
||||
return rewriter.notifyMatchFailure(op, "dynamic output shape");
|
||||
if (resultTy.getSizes().size() == 0) {
|
||||
rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
|
||||
op, op.getType(), elements.front());
|
||||
return success();
|
||||
}
|
||||
|
||||
auto loc = op.getLoc();
|
||||
SmallVector<Value> sizes;
|
||||
|
@ -922,12 +1072,10 @@ public:
|
|||
sizes.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(size)));
|
||||
|
||||
Value one = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getType<Torch::IntType>(), 1);
|
||||
Value sizeList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc,
|
||||
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
|
||||
one);
|
||||
sizes);
|
||||
|
||||
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
||||
|
@ -1031,6 +1179,24 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// fold ridiculous patterns like size.int -> float.scalar -> int.scalar
|
||||
class FoldAtenIntScalarPattern : public OpRewritePattern<AtenIntScalarOp> {
|
||||
public:
|
||||
using OpRewritePattern<AtenIntScalarOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenIntScalarOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto floatScalarOp = op.getA().getDefiningOp<AtenFloatScalarOp>();
|
||||
if (!floatScalarOp)
|
||||
return failure();
|
||||
auto sizeOp = floatScalarOp.getA().getDefiningOp<AtenSizeIntOp>();
|
||||
if (!sizeOp)
|
||||
return failure();
|
||||
rewriter.replaceOp(op, floatScalarOp.getA());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
namespace {
|
||||
class FoldAtenUnsqueezePattern : public OpRewritePattern<AtenUnsqueezeOp> {
|
||||
public:
|
||||
|
@ -1182,8 +1348,29 @@ public:
|
|||
if (inputUnmatched == 1 && outputUnmatched > 1) {
|
||||
Value dimVal =
|
||||
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), leftMatchEnd);
|
||||
ArrayRef<Value> unflattenSizes(viewSizes.begin() + leftMatchEnd,
|
||||
viewSizes.end() - rightMatchEnd);
|
||||
SmallVector<Value> unflattenSizes(viewSizes.begin() + leftMatchEnd,
|
||||
viewSizes.end() - rightMatchEnd);
|
||||
// try to convert a single dynamic size input to -1
|
||||
int64_t dynCount = 0;
|
||||
int64_t dynIdx = 0;
|
||||
for (auto [i, v] : llvm::enumerate(unflattenSizes)) {
|
||||
int64_t szeInt;
|
||||
if (!matchPattern(v, m_TorchConstantInt(&szeInt))) {
|
||||
dynCount++;
|
||||
dynIdx = i;
|
||||
continue;
|
||||
}
|
||||
// if we have a -1 already, make dynCount invalid and break
|
||||
if (szeInt == -1) {
|
||||
dynCount = -1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// if only one size is dynamic, make it -1
|
||||
if (dynCount == 1)
|
||||
unflattenSizes[dynIdx] =
|
||||
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), -1);
|
||||
|
||||
Value unflattenList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
op.getLoc(), op.getSize().getType(), unflattenSizes);
|
||||
rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
|
||||
|
@ -1227,6 +1414,18 @@ public:
|
|||
|
||||
namespace {
|
||||
|
||||
bool isItemForSliceOp(Operation *op) {
|
||||
auto itemOp = dyn_cast_or_null<AtenItemOp>(op);
|
||||
if (!itemOp)
|
||||
return false;
|
||||
for (OpOperand &use : op->getUses()) {
|
||||
Operation *userOp = use.getOwner();
|
||||
if (isa<AtenSliceTensorOp>(userOp))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool isSourceOpForShapeScalarization(Operation *op) {
|
||||
return llvm::isa<AtenSizeIntOp, Torch::ConstantIntOp, Torch::ConstantBoolOp,
|
||||
Aten_ShapeAsTensorOp, Torch::ValueTensorLiteralOp>(op);
|
||||
|
@ -1244,7 +1443,7 @@ bool isPrimListOfInts(Operation *op) {
|
|||
|
||||
bool isAnchorOp(Operation *op) {
|
||||
return isa<Torch::RuntimeAssertOp>(op) || isa<AtenArangeStartStepOp>(op) ||
|
||||
isPrimListOfInts(op);
|
||||
isPrimListOfInts(op) || isItemForSliceOp(op);
|
||||
}
|
||||
|
||||
// The argument to this function, op, is the use of some source op, srcOp. If
|
||||
|
@ -1278,9 +1477,9 @@ bool isInvalidValidViewConsumer(Operation *op,
|
|||
void populateScalarizationFoldPatterns(RewritePatternSet &patterns) {
|
||||
patterns.insert<FoldAtenSqueezePattern<AtenSqueezeOp>,
|
||||
FoldAtenSqueezePattern<AtenSqueezeDimOp>,
|
||||
FoldAtenUnsqueezePattern, FoldAtenWhereSelf,
|
||||
FoldAtenTensorSplatPattern, FoldAtenEqIntPattern>(
|
||||
patterns.getContext());
|
||||
FoldAtenIntScalarPattern, FoldAtenUnsqueezePattern,
|
||||
FoldAtenWhereSelf, FoldAtenTensorSplatPattern,
|
||||
FoldAtenEqIntPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) {
|
||||
|
@ -1303,10 +1502,12 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
|
|||
PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
|
||||
PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
|
||||
PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern,
|
||||
PropagateAtenTransposeIntPattern,
|
||||
PropagateAtenTransposeIntPattern, PropagateAtenToDtypePattern,
|
||||
PropagateAtenUnaryPattern<AtenNegOp, AtenNegIntOp>,
|
||||
PropagateAtenArithmeticPattern<AtenAddTensorOp, AtenAddIntOp>,
|
||||
PropagateAtenArithmeticPattern<AtenSubTensorOp, AtenSubIntOp>,
|
||||
PropagateAtenArithmeticPattern<AtenMulTensorOp, AtenMulIntOp>,
|
||||
PropagateAtenArithmeticPattern<AtenRemainderTensorOp, AtenRemainderIntOp>,
|
||||
PropagateAtenArithmeticPattern<AtenDivTensorOp, AtenFloordivIntOp>>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
@ -1314,6 +1515,7 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
|
|||
void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
|
||||
patterns.insert<RemoveUnusedPattern<Torch::AtenIntBoolOp>,
|
||||
RemoveUnusedPattern<Torch::AtenEqIntOp>,
|
||||
RemoveUnusedPattern<Torch::AtenToDtypeOp>,
|
||||
RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
|
||||
RemoveUnusedPattern<Torch::AtenFullOp>,
|
||||
RemoveUnusedPattern<Torch::AtenUnsqueezeOp>,
|
||||
|
@ -1321,6 +1523,8 @@ void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
|
|||
RemoveUnusedPattern<Torch::AtenSizeIntOp>,
|
||||
RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
|
||||
RemoveUnusedPattern<Torch::AtenTensorOp>,
|
||||
RemoveUnusedPattern<Torch::AtenFloatScalarOp>,
|
||||
RemoveUnusedPattern<Torch::AtenIntScalarOp>,
|
||||
RemoveUnusedPattern<Torch::PrimListConstructOp>>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
|
|
@ -105,9 +105,9 @@ func.func @torch.aten.fake_quantize_per_channel_affine_cachemask(%arg0: !torch.v
|
|||
|
||||
// CHECK-LABEL: test_einsum_inner_prod
|
||||
func.func @test_einsum_inner_prod(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[5],f64>) -> !torch.vtensor<[],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64} {
|
||||
// CHECK: %[[INT5:.+]] = torch.constant.int 5
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[INT5:.+]] = torch.constant.int 5
|
||||
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[LHS_LIST:.+]] = torch.prim.ListConstruct %[[INT0]]
|
||||
// CHECK: %[[LHS_PERM:.+]] = torch.aten.permute %arg0, %[[LHS_LIST]]
|
||||
// CHECK: %[[RHS_LIST:.+]] = torch.prim.ListConstruct %[[INT0]]
|
||||
|
|
|
@ -27,12 +27,8 @@ func.func @shape_as_tensor(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtenso
|
|||
// CHECK-LABEL: @shape_as_tensor_dim
|
||||
func.func @shape_as_tensor_dim(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],si32> {
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[INT1]]
|
||||
// CHECK: %[[INT1_0:.+]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
|
||||
// CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT1_0]]
|
||||
// CHECK: %[[TENSOR:.+]] = torch.aten.full %[[LIST]], %[[SZ]], %[[NONE]], %[[NONE]], %[[NONE]], %[[FALSE]]
|
||||
// CHECK-DAG: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[INT1]]
|
||||
// CHECK: %[[TENSOR:.+]] = torch.prim.NumToTensor.Scalar %[[SZ]] : !torch.int -> !torch.vtensor<[],si32>
|
||||
// CHECK: return %[[TENSOR]] : !torch.vtensor<[],si32>
|
||||
%shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32>
|
||||
%dim = torch.constant.int 0
|
||||
|
@ -43,6 +39,49 @@ func.func @shape_as_tensor_dim(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vt
|
|||
return %select : !torch.vtensor<[],si32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @cast_int_int
|
||||
func.func @cast_int_int(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],si64> {
|
||||
// CHECK: %[[I1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[SZE:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[5,?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SZE]] : !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: return %[[TENSOR]] : !torch.vtensor<[],si64>
|
||||
%int4 = torch.constant.int 4
|
||||
%false = torch.constant.bool false
|
||||
%none = torch.constant.none
|
||||
%shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32>
|
||||
%cast_shape = torch.aten.to.dtype %shape, %int4, %false, %false, %none : !torch.vtensor<[3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3],si64>
|
||||
%dim = torch.constant.int 0
|
||||
%idx = torch.vtensor.literal(dense<1> : tensor<si32>) : !torch.vtensor<[],si32>
|
||||
%select = torch.aten.index_select %cast_shape, %dim, %idx : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si64>
|
||||
%item = torch.aten.item %select : !torch.vtensor<[],si64> -> !torch.int
|
||||
%list = torch.prim.ListConstruct %item : (!torch.int) -> !torch.list<int>
|
||||
return %select : !torch.vtensor<[],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @cast_int_float
|
||||
func.func @cast_int_float(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],f32> {
|
||||
// CHECK: %[[I1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[SZE:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[5,?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[FLOAT:.*]] = torch.aten.Float.Scalar %[[SZE]] : !torch.int -> !torch.float
|
||||
// CHECK: %[[TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[FLOAT]] : !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK: return %[[TENSOR]] : !torch.vtensor<[],f32>
|
||||
%int6 = torch.constant.int 6
|
||||
%false = torch.constant.bool false
|
||||
%none = torch.constant.none
|
||||
%shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32>
|
||||
%cast_shape = torch.aten.to.dtype %shape, %int6, %false, %false, %none : !torch.vtensor<[3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3],f32>
|
||||
%dim = torch.constant.int 0
|
||||
%idx = torch.vtensor.literal(dense<1> : tensor<si32>) : !torch.vtensor<[],si32>
|
||||
%select = torch.aten.index_select %cast_shape, %dim, %idx : !torch.vtensor<[3],f32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],f32>
|
||||
%item = torch.aten.item %select : !torch.vtensor<[],f32> -> !torch.float
|
||||
%item_int = torch.aten.Int.Scalar %item : !torch.float -> !torch.int
|
||||
%list = torch.prim.ListConstruct %item_int : (!torch.int) -> !torch.list<int>
|
||||
return %select : !torch.vtensor<[],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
@ -89,14 +128,12 @@ func.func @arith_prop(%arg0 : !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?]
|
|||
// CHECK: %[[x2:.*]] = torch.aten.floordiv.int %[[x0]], %[[int12]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[x3:.*]] = torch.aten.floordiv.int %[[x1]], %[[int1_0]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[int12_1:.*]] = torch.constant.int 12
|
||||
// CHECK: %[[int1_2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[x4:.*]] = torch.aten.mul.int %[[x2]], %[[int12_1]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[x5:.*]] = torch.aten.mul.int %[[x3]], %[[int1_2]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[x6:.*]] = torch.aten.sub.int %[[x0]], %[[x4]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[x7:.*]] = torch.aten.sub.int %[[x1]], %[[x5]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[x8:.*]] = torch.prim.ListConstruct %[[x7]], %[[x6]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[x9:.*]] = torch.aten.constant_pad_nd %arg0, %[[x8]], %[[float0]] : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[x9]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[x5:.*]] = torch.aten.sub.int %[[x0]], %[[x4]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[x6:.*]] = torch.aten.sub.int %[[x1]], %[[x3]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[x7:.*]] = torch.prim.ListConstruct %[[x6]], %[[x5]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[x8:.*]] = torch.aten.constant_pad_nd %arg0, %[[x7]], %[[float0]] : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[x8]] : !torch.vtensor<[?,?],f32>
|
||||
%0 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||
%1 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||
%float0.000000e00 = torch.constant.float 0.000000e+00
|
||||
|
|
Loading…
Reference in New Issue