mirror of https://github.com/llvm/torch-mlir
Convert Torch constant ops to std.constant
parent
78d2cc0818
commit
e6adecac83
|
@ -71,14 +71,14 @@ public:
|
|||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertValueTensorLiteralOp
|
||||
: public OpConversionPattern<ValueTensorLiteralOp> {
|
||||
template <typename OpTy>
|
||||
class ConvertTorchConstantOp : public OpConversionPattern<OpTy> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
using OpConversionPattern<OpTy>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(ValueTensorLiteralOp op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(OpTy op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op, op.value());
|
||||
rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op, op.valueAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -112,7 +112,17 @@ public:
|
|||
target.addIllegalOp<AtenGtIntOp>();
|
||||
patterns.add<ConvertAtenGtIntOp>(typeConverter, context);
|
||||
target.addIllegalOp<ValueTensorLiteralOp>();
|
||||
patterns.add<ConvertValueTensorLiteralOp>(typeConverter, context);
|
||||
patterns.add<ConvertTorchConstantOp<ValueTensorLiteralOp>>(typeConverter,
|
||||
context);
|
||||
target.addIllegalOp<ConstantBoolOp>();
|
||||
patterns.add<ConvertTorchConstantOp<ConstantBoolOp>>(typeConverter,
|
||||
context);
|
||||
target.addIllegalOp<Torch::ConstantFloatOp>();
|
||||
patterns.add<ConvertTorchConstantOp<Torch::ConstantFloatOp>>(typeConverter,
|
||||
context);
|
||||
target.addIllegalOp<Torch::ConstantIntOp>();
|
||||
patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter,
|
||||
context);
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
|
|
@ -47,3 +47,30 @@ func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
|
|||
%0 = torch.vtensor.literal(dense<0.0> : tensor<f32>) : !torch.vtensor<[],f32>
|
||||
return %0 : !torch.vtensor<[],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.constant.bool() -> !torch.bool {
|
||||
// CHECK: %[[CST:.*]] = constant true
|
||||
// CHECK: %[[BOOL:.*]] = torch.from_i1 %[[CST]]
|
||||
// CHECK: return %[[BOOL]] : !torch.bool
|
||||
func @torch.constant.bool() -> !torch.bool {
|
||||
%true = torch.constant.bool true
|
||||
return %true : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.constant.float() -> !torch.float {
|
||||
// CHECK: %[[CST:.*]] = constant 1.000000e+00 : f64
|
||||
// CHECK: %[[FLOAT:.*]] = torch.from_f64 %[[CST]]
|
||||
// CHECK: return %[[FLOAT]] : !torch.float
|
||||
func @torch.constant.float() -> !torch.float {
|
||||
%float = torch.constant.float 1.000000e+00
|
||||
return %float : !torch.float
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.constant.int() -> !torch.int {
|
||||
// CHECK: %[[CST:.*]] = constant 1 : i64
|
||||
// CHECK: %[[INT:.*]] = torch.from_i64 %[[CST]]
|
||||
// CHECK: return %[[INT]] : !torch.int
|
||||
func @torch.constant.int() -> !torch.int {
|
||||
%int1 = torch.constant.int 1
|
||||
return %int1 : !torch.int
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue