Convert Torch constant ops to std.constant

pull/234/head
Yi Zhang 2021-06-18 17:44:43 +00:00 committed by Sean Silva
parent 78d2cc0818
commit e6adecac83
2 changed files with 43 additions and 6 deletions

View File

@ -71,14 +71,14 @@ public:
} // namespace
namespace {
class ConvertValueTensorLiteralOp
: public OpConversionPattern<ValueTensorLiteralOp> {
template <typename OpTy>
class ConvertTorchConstantOp : public OpConversionPattern<OpTy> {
public:
using OpConversionPattern::OpConversionPattern;
using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult
matchAndRewrite(ValueTensorLiteralOp op, ArrayRef<Value> operands,
matchAndRewrite(OpTy op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op, op.value());
rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op, op.valueAttr());
return success();
}
};
@ -112,7 +112,17 @@ public:
target.addIllegalOp<AtenGtIntOp>();
patterns.add<ConvertAtenGtIntOp>(typeConverter, context);
target.addIllegalOp<ValueTensorLiteralOp>();
patterns.add<ConvertValueTensorLiteralOp>(typeConverter, context);
patterns.add<ConvertTorchConstantOp<ValueTensorLiteralOp>>(typeConverter,
context);
target.addIllegalOp<ConstantBoolOp>();
patterns.add<ConvertTorchConstantOp<ConstantBoolOp>>(typeConverter,
context);
target.addIllegalOp<Torch::ConstantFloatOp>();
patterns.add<ConvertTorchConstantOp<Torch::ConstantFloatOp>>(typeConverter,
context);
target.addIllegalOp<Torch::ConstantIntOp>();
patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter,
context);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();

View File

@ -47,3 +47,30 @@ func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
%0 = torch.vtensor.literal(dense<0.0> : tensor<f32>) : !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32>
}
// CHECK-LABEL: func @torch.constant.bool() -> !torch.bool {
// CHECK: %[[CST:.*]] = constant true
// CHECK: %[[BOOL:.*]] = torch.from_i1 %[[CST]]
// CHECK: return %[[BOOL]] : !torch.bool
func @torch.constant.bool() -> !torch.bool {
%true = torch.constant.bool true
return %true : !torch.bool
}
// CHECK-LABEL: func @torch.constant.float() -> !torch.float {
// CHECK: %[[CST:.*]] = constant 1.000000e+00 : f64
// CHECK: %[[FLOAT:.*]] = torch.from_f64 %[[CST]]
// CHECK: return %[[FLOAT]] : !torch.float
func @torch.constant.float() -> !torch.float {
%float = torch.constant.float 1.000000e+00
return %float : !torch.float
}
// CHECK-LABEL: func @torch.constant.int() -> !torch.int {
// CHECK: %[[CST:.*]] = constant 1 : i64
// CHECK: %[[INT:.*]] = torch.from_i64 %[[CST]]
// CHECK: return %[[INT]] : !torch.int
func @torch.constant.int() -> !torch.int {
%int1 = torch.constant.int 1
return %int1 : !torch.int
}