mirror of https://github.com/llvm/torch-mlir
This reverts commit c935795086
.
pull/1217/head
snapshot-20220817.567
parent
85f383ce0b
commit
9be8997536
|
@ -5319,10 +5319,9 @@ def Torch_AtenOnesLikeOp : Torch_Op<"aten.ones_like", [
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
|
def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
|
||||||
NoSideEffect,
|
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
ReadOnly,
|
ReadOnly
|
||||||
]> {
|
]> {
|
||||||
let summary = "Generated op for `aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)`";
|
let summary = "Generated op for `aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)`";
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
|
|
|
@ -71,25 +71,6 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// ConvertAtenUnaryConvertOp legalize genearl unary ops into Mhlo ConverOp
|
|
||||||
namespace {
|
|
||||||
template <typename AtenOpT>
|
|
||||||
class ConvertAtenUnaryConvertOp: public OpConversionPattern<AtenOpT> {
|
|
||||||
public:
|
|
||||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
|
||||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
|
||||||
LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
|
||||||
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
|
|
||||||
op,
|
|
||||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
|
||||||
op.getType()),
|
|
||||||
adaptor.self());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// aten.ones & aten.zeros
|
// aten.ones & aten.zeros
|
||||||
// Ref: Error checking based on the Torch to TOSA lowering
|
// Ref: Error checking based on the Torch to TOSA lowering
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -329,9 +310,6 @@ public:
|
||||||
std::is_same<AtenOpT, AtenGtScalarOp>()) {
|
std::is_same<AtenOpT, AtenGtScalarOp>()) {
|
||||||
compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
|
compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
|
||||||
op->getContext(), mhlo::ComparisonDirection::GT);
|
op->getContext(), mhlo::ComparisonDirection::GT);
|
||||||
} else if (std::is_same<AtenOpT, AtenGeScalarOp>()) {
|
|
||||||
compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
|
|
||||||
op->getContext(), mhlo::ComparisonDirection::GE);
|
|
||||||
} else if (std::is_same<AtenOpT, AtenEqTensorOp>() ||
|
} else if (std::is_same<AtenOpT, AtenEqTensorOp>() ||
|
||||||
std::is_same<AtenOpT, AtenEqScalarOp>()) {
|
std::is_same<AtenOpT, AtenEqScalarOp>()) {
|
||||||
compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
|
compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
|
||||||
|
@ -1005,72 +983,6 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// AtenSizeIntOp
|
|
||||||
namespace {
|
|
||||||
template <>
|
|
||||||
LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
|
|
||||||
AtenSizeIntOp op,
|
|
||||||
OpAdaptor adaptor,
|
|
||||||
ConversionPatternRewriter& rewriter) const {
|
|
||||||
// Not a tensor type.
|
|
||||||
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
|
|
||||||
if (!selfType)
|
|
||||||
return op.emitError("Only tensor types are currently supported");
|
|
||||||
auto dim = rewriter.create<arith::IndexCastOp>(
|
|
||||||
op.getLoc(), rewriter.getIndexType(), adaptor.dim());
|
|
||||||
auto dimSize = rewriter.create<tensor::DimOp>(
|
|
||||||
op.getLoc(), rewriter.getIndexType(), adaptor.self(), dim);
|
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
|
|
||||||
op, getTypeConverter()->convertType(op.getType()), dimSize);
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// ValsemVariantAtenUniformOp
|
|
||||||
namespace {
|
|
||||||
template <>
|
|
||||||
LogicalResult ConvertAtenOp<ValsemVariantAtenUniformOp>::matchAndRewrite(
|
|
||||||
ValsemVariantAtenUniformOp op,
|
|
||||||
OpAdaptor adaptor,
|
|
||||||
ConversionPatternRewriter& rewriter) const {
|
|
||||||
auto inputTy = adaptor.self().getType().template cast<RankedTensorType>();
|
|
||||||
auto loc = op.getLoc();
|
|
||||||
if (!inputTy) {
|
|
||||||
op.emitError("input should be ranked tensor type.");
|
|
||||||
}
|
|
||||||
auto definingOp = op.self().getDefiningOp();
|
|
||||||
auto shape = definingOp->getOperand(0);
|
|
||||||
SmallVector<Value, 4> dimSizes;
|
|
||||||
getListConstructElements(shape, dimSizes);
|
|
||||||
std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value& dSize) {
|
|
||||||
dSize = rewriter.create<torch::TorchConversion::ToI64Op>(loc, dSize).getResult();
|
|
||||||
return dSize;
|
|
||||||
});
|
|
||||||
|
|
||||||
auto mhloShape =
|
|
||||||
rewriter.create<tensor::FromElementsOp>(op.getLoc(), dimSizes);
|
|
||||||
|
|
||||||
double fromDoubleValue, toDoubleValue;
|
|
||||||
if (!matchPattern(op.from(), m_TorchConstantFloat(&fromDoubleValue))) {
|
|
||||||
op.emitError("operand #1 should be scalar");
|
|
||||||
}
|
|
||||||
if (!matchPattern(op.to(), m_TorchConstantFloat(&toDoubleValue))) {
|
|
||||||
op.emitError("operand #2 should be scalar");
|
|
||||||
}
|
|
||||||
Value fromTensor = rewriter.create<mhlo::ConstantOp>(
|
|
||||||
op.getLoc(),
|
|
||||||
rewriter.getFloatAttr(inputTy.getElementType(), fromDoubleValue));
|
|
||||||
Value toTensor = rewriter.create<mhlo::ConstantOp>(
|
|
||||||
op.getLoc(),
|
|
||||||
rewriter.getFloatAttr(inputTy.getElementType(), toDoubleValue));
|
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<mhlo::RngOp>(
|
|
||||||
op, inputTy, fromTensor, toTensor, mhloShape, mhlo::RngDistribution::UNIFORM);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
||||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
ConversionTarget &target) {
|
ConversionTarget &target) {
|
||||||
|
@ -1096,15 +1008,6 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
||||||
INSERT_UNARY_FPONLY_PATTERN(AtenNegOp, mhlo::NegOp);
|
INSERT_UNARY_FPONLY_PATTERN(AtenNegOp, mhlo::NegOp);
|
||||||
#undef INSERT_UNARY_FPONLY_PATTERN
|
#undef INSERT_UNARY_FPONLY_PATTERN
|
||||||
|
|
||||||
#define INSERT_UNARY_CONVERT_PATTERN(AtenOp) \
|
|
||||||
target.addIllegalOp<AtenOp>(); \
|
|
||||||
patterns.add<ConvertAtenUnaryConvertOp<AtenOp>>(typeConverter, \
|
|
||||||
context);
|
|
||||||
INSERT_UNARY_CONVERT_PATTERN(AtenContiguousOp);
|
|
||||||
INSERT_UNARY_CONVERT_PATTERN(AtenToDtypeOp);
|
|
||||||
INSERT_UNARY_CONVERT_PATTERN(AtenTypeAsOp);
|
|
||||||
#undef INSERT_UNARY_CONVERT_PATTERN
|
|
||||||
|
|
||||||
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
|
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
|
||||||
target.addIllegalOp<AtenOp>(); \
|
target.addIllegalOp<AtenOp>(); \
|
||||||
patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \
|
patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \
|
||||||
|
@ -1139,7 +1042,6 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
||||||
|
|
||||||
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp);
|
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp);
|
||||||
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp);
|
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp);
|
||||||
INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp);
|
|
||||||
INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp);
|
INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp);
|
||||||
INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp);
|
INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp);
|
||||||
INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp);
|
INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp);
|
||||||
|
@ -1165,7 +1067,5 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
||||||
|
|
||||||
INSERT_ATENOP_PATTERN(AtenBatchNormOp);
|
INSERT_ATENOP_PATTERN(AtenBatchNormOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
|
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenSizeIntOp);
|
|
||||||
INSERT_ATENOP_PATTERN(ValsemVariantAtenUniformOp);
|
|
||||||
#undef INSERT_ATENOP_PATTERN
|
#undef INSERT_ATENOP_PATTERN
|
||||||
}
|
}
|
||||||
|
|
|
@ -1155,47 +1155,6 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
|
||||||
class DecomposeAtenNativeDropoutOp : public OpRewritePattern<AtenNativeDropoutOp> {
|
|
||||||
public:
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
LogicalResult matchAndRewrite(AtenNativeDropoutOp op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
auto loc = op.getLoc();
|
|
||||||
Value input = op.input();
|
|
||||||
Value prob = op.p();
|
|
||||||
bool train = false;
|
|
||||||
if (!matchPattern(op.train(), m_TorchConstantBool(&train)))
|
|
||||||
return rewriter.notifyMatchFailure(op, "train must be a boolean constant");
|
|
||||||
|
|
||||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
|
||||||
if (!train) {
|
|
||||||
// TODO(yancey.yx): supports inference mode
|
|
||||||
return op.emitError(
|
|
||||||
"native_dropout does not support argument train is false");
|
|
||||||
}
|
|
||||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>())
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "only support floating type input for training mode");
|
|
||||||
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
|
|
||||||
Value floatOne =
|
|
||||||
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
|
||||||
Value oneMinusP = rewriter.create<AtenSubFloatOp>(loc, floatOne, prob);
|
|
||||||
Value boolMask = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
|
|
||||||
loc, inputType, input, oneMinusP, /*generator=*/noneVal);
|
|
||||||
Value maskedInput =
|
|
||||||
rewriter.create<AtenMulTensorOp>(loc, inputType, boolMask, input);
|
|
||||||
Value output =
|
|
||||||
rewriter.create<AtenMulScalarOp>(loc, inputType, maskedInput, oneMinusP);
|
|
||||||
Value one =
|
|
||||||
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
||||||
boolMask = rewriter.create<AtenGeScalarOp>(
|
|
||||||
loc, op.getResult(1).getType(), boolMask, one);
|
|
||||||
rewriter.replaceOp(op, {output, boolMask});
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
// Decompose aten.var into: aten.var.dim op.
|
// Decompose aten.var into: aten.var.dim op.
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeAtenVarOp : public OpRewritePattern<AtenVarOp> {
|
class DecomposeAtenVarOp : public OpRewritePattern<AtenVarOp> {
|
||||||
|
@ -2635,8 +2594,6 @@ class DecomposeComplexOpsPass
|
||||||
patterns.add<DecomposeAten_ToCopyOp>(context);
|
patterns.add<DecomposeAten_ToCopyOp>(context);
|
||||||
target.addIllegalOp<Aten_ToCopyOp>();
|
target.addIllegalOp<Aten_ToCopyOp>();
|
||||||
patterns.add<DecomposeAtenDropoutOp>(context);
|
patterns.add<DecomposeAtenDropoutOp>(context);
|
||||||
patterns.add<DecomposeAtenNativeDropoutOp>(context);
|
|
||||||
target.addIllegalOp<AtenNativeDropoutOp>();
|
|
||||||
target.addIllegalOp<AtenDropoutOp>();
|
target.addIllegalOp<AtenDropoutOp>();
|
||||||
target.addIllegalOp<AtenNewEmptyOp>();
|
target.addIllegalOp<AtenNewEmptyOp>();
|
||||||
patterns.add<DecomposeAtenNewEmptyOp>(context);
|
patterns.add<DecomposeAtenNewEmptyOp>(context);
|
||||||
|
|
|
@ -139,8 +139,6 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline(
|
||||||
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
|
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
|
||||||
|
|
||||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass());
|
pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass());
|
||||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
|
|
||||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
|
|
||||||
|
|
||||||
if (options.optimize) {
|
if (options.optimize) {
|
||||||
// Clean up any non-canonical code introduced above..
|
// Clean up any non-canonical code introduced above..
|
||||||
|
|
|
@ -1,47 +0,0 @@
|
||||||
// RUN: torch-mlir-opt < %s --torch-function-to-torch-backend-pipeline --torch-backend-to-mhlo-backend-pipeline -split-input-file -verify-diagnostics | FileCheck %s
|
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.native_dropout.train(
|
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: f64) -> (tensor<?x?xf32>, tensor<?x?xi1>) {
|
|
||||||
// CHECK: %[[T0:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
|
|
||||||
// CHECK: %[[CST_0:.*]] = arith.constant 1 : index
|
|
||||||
// CHECK: %[[CST_1:.*]] = arith.constant 0 : index
|
|
||||||
// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f64>
|
|
||||||
// CHECK: %[[T2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f64>
|
|
||||||
// CHECK: %[[CST_2:.*]] = arith.constant 1.000000e+00 : f64
|
|
||||||
// CHECK: %[[CST_3:.*]] = arith.subf %[[CST_2]], %[[ARG1]] : f64
|
|
||||||
// CHECK: %[[T3:.*]] = tensor.from_elements %[[CST_3]] : tensor<1xf64>
|
|
||||||
// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf64>) -> tensor<f64>
|
|
||||||
// CHECK: %[[T5:.*]] = mhlo.convert(%[[ARG0]]) : (tensor<?x?xf32>) -> tensor<?x?xf64>
|
|
||||||
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T5]], %[[CST_1]] : tensor<?x?xf64>
|
|
||||||
// CHECK: %[[CST_I64_0:.*]] = arith.index_cast %[[DIM_0]] : index to i64
|
|
||||||
// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T5]], %[[CST_0]] : tensor<?x?xf64>
|
|
||||||
// CHECK: %[[CST_I64_1:.*]] = arith.index_cast %[[DIM_1]] : index to i64
|
|
||||||
// CHECK: %[[T6:.*]] = tensor.from_elements %[[CST_I64_0]], %[[CST_I64_1]] : tensor<2xi64>
|
|
||||||
// CHECK: %[[T7:.*]] = "mhlo.rng"(%[[T2]], %[[T1]], %[[T6]]) {rng_distribution = #mhlo.rng_distribution<UNIFORM>} : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<?x?xf64>
|
|
||||||
// CHECK: %[[T8:.*]] = shape.shape_of %[[T7]] : tensor<?x?xf64> -> tensor<2xindex>
|
|
||||||
// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T4]], %[[T8]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f64>, tensor<2xindex>) -> tensor<?x?xf64>
|
|
||||||
// CHECK: %[[T10:.*]] = mhlo.compare LT, %[[T7]], %[[T9]], FLOAT : (tensor<?x?xf64>, tensor<?x?xf64>) -> tensor<?x?xi1>
|
|
||||||
// CHECK: %[[T11:.*]] = mhlo.convert(%[[T10]]) : (tensor<?x?xi1>) -> tensor<?x?xf32>
|
|
||||||
// CHECK: %[[T12:.*]] = shape.shape_of %[[T11]] : tensor<?x?xf32> -> tensor<2xindex>
|
|
||||||
// CHECK: %[[T13:.*]] = shape.shape_of %[[ARG0]] : tensor<?x?xf32> -> tensor<2xindex>
|
|
||||||
// CHECK: %[[T14:.*]] = shape.cstr_broadcastable %[[T12]], %[[T13]] : tensor<2xindex>, tensor<2xindex>
|
|
||||||
// CHECK: %[[T15:.*]] = shape.assuming %[[T14]] -> (tensor<?x?xf32>) {
|
|
||||||
// CHECK: %[[T16:.*]] = shape.broadcast %[[T12]], %[[T13]] : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
|
|
||||||
// CHECK: %[[T17:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T11]], %[[T16]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
|
||||||
// CHECK: %[[T18:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[T16]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
|
||||||
// CHECK: %[[T19:.*]] = mhlo.multiply %[[T17]], %[[T18]] : tensor<?x?xf32>
|
|
||||||
// CHECK: shape.assuming_yield %[[T19]] : tensor<?x?xf32>
|
|
||||||
// CHECK: }
|
|
||||||
// CHECK: %[[T20:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xf64>) -> tensor<1xf32>
|
|
||||||
// CHECK: %[[T21:.*]] = mhlo.reshape %[[T20]] : (tensor<1xf32>) -> tensor<f32>
|
|
||||||
// CHECK: %[[T22:.*]] = shape.shape_of %[[T15]] : tensor<?x?xf32> -> tensor<2xindex>
|
|
||||||
// CHECK: %[[T23:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T21]], %[[T22]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
|
|
||||||
// CHECK: %[[T24:.*]] = mhlo.multiply %[[T15]], %[[T23]] : tensor<?x?xf32>
|
|
||||||
// CHECK: %[[T25:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T12]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
|
|
||||||
// CHECK: %[[T26:.*]] = mhlo.compare GE, %[[T11]], %[[T25]], FLOAT : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
|
|
||||||
// CHECK: return %[[T24]], %[[T26]] : tensor<?x?xf32>, tensor<?x?xi1>
|
|
||||||
func.func @torch.aten.native_dropout.train(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.float) -> (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>) {
|
|
||||||
%bool_true = torch.constant.bool true
|
|
||||||
%result0, %result1 = torch.aten.native_dropout %arg0, %arg1, %bool_true: !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>
|
|
||||||
return %result0, %result1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>
|
|
||||||
}
|
|
|
@ -360,17 +360,9 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !
|
||||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3,?,?],f32> -> tensor<2x3x?x?xf32>
|
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3,?,?],f32> -> tensor<2x3x?x?xf32>
|
||||||
// CHECK: %[[INTneg1:.*]] = torch.constant.int -1
|
// CHECK: %[[INTneg1:.*]] = torch.constant.int -1
|
||||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
// CHECK: %[[C1_I64:.*]] = torch_c.to_i64 %[[INT1]]
|
|
||||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||||
// CHECK: %[[C2_I64:.*]] = torch_c.to_i64 %[[INT0]]
|
// CHECK: %[[T1:.*]] = torch.aten.size.int %[[ARG0]], %[[INT0]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
|
||||||
// CHECK: %[[INDEX_1:.*]] = arith.index_cast %[[C2_I64]] : i64 to index
|
// CHECK: %[[T2:.*]] = torch.aten.size.int %[[ARG0]], %[[INT1]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
|
||||||
// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[INDEX_1]] : tensor<2x3x?x?xf32>
|
|
||||||
// CHECK: %[[DIM_I64_1:.*]] = arith.index_cast %[[DIM_1]] : index to i64
|
|
||||||
// CHECK: %[[T1:.*]] = torch_c.from_i64 %[[DIM_I64_1]]
|
|
||||||
// CHECK: %[[INDEX_2:.*]] = arith.index_cast %[[C1_I64]] : i64 to index
|
|
||||||
// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[INDEX_2]] : tensor<2x3x?x?xf32>
|
|
||||||
// CHECK: %[[DIM_I64_2:.*]] = arith.index_cast %[[DIM_2]] : index to i64
|
|
||||||
// CHECK: %[[T2:.*]] = torch_c.from_i64 %[[DIM_I64_2]]
|
|
||||||
// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]], %[[INTneg1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]], %[[INTneg1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
// CHECK: %[[T4:.*]] = torch_c.to_i64 %[[T1]]
|
// CHECK: %[[T4:.*]] = torch_c.to_i64 %[[T1]]
|
||||||
// CHECK: %[[T5:.*]] = torch_c.to_i64 %[[T2]]
|
// CHECK: %[[T5:.*]] = torch_c.to_i64 %[[T2]]
|
||||||
|
|
Loading…
Reference in New Issue