mirror of https://github.com/llvm/torch-mlir
Convert Torch constant ops to std.constant
parent
78d2cc0818
commit
e6adecac83
|
@ -71,14 +71,14 @@ public:
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class ConvertValueTensorLiteralOp
|
template <typename OpTy>
|
||||||
: public OpConversionPattern<ValueTensorLiteralOp> {
|
class ConvertTorchConstantOp : public OpConversionPattern<OpTy> {
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern::OpConversionPattern;
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(ValueTensorLiteralOp op, ArrayRef<Value> operands,
|
matchAndRewrite(OpTy op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op, op.value());
|
rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op, op.valueAttr());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -112,7 +112,17 @@ public:
|
||||||
target.addIllegalOp<AtenGtIntOp>();
|
target.addIllegalOp<AtenGtIntOp>();
|
||||||
patterns.add<ConvertAtenGtIntOp>(typeConverter, context);
|
patterns.add<ConvertAtenGtIntOp>(typeConverter, context);
|
||||||
target.addIllegalOp<ValueTensorLiteralOp>();
|
target.addIllegalOp<ValueTensorLiteralOp>();
|
||||||
patterns.add<ConvertValueTensorLiteralOp>(typeConverter, context);
|
patterns.add<ConvertTorchConstantOp<ValueTensorLiteralOp>>(typeConverter,
|
||||||
|
context);
|
||||||
|
target.addIllegalOp<ConstantBoolOp>();
|
||||||
|
patterns.add<ConvertTorchConstantOp<ConstantBoolOp>>(typeConverter,
|
||||||
|
context);
|
||||||
|
target.addIllegalOp<Torch::ConstantFloatOp>();
|
||||||
|
patterns.add<ConvertTorchConstantOp<Torch::ConstantFloatOp>>(typeConverter,
|
||||||
|
context);
|
||||||
|
target.addIllegalOp<Torch::ConstantIntOp>();
|
||||||
|
patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter,
|
||||||
|
context);
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
std::move(patterns))))
|
std::move(patterns))))
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
|
|
|
@ -47,3 +47,30 @@ func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
|
||||||
%0 = torch.vtensor.literal(dense<0.0> : tensor<f32>) : !torch.vtensor<[],f32>
|
%0 = torch.vtensor.literal(dense<0.0> : tensor<f32>) : !torch.vtensor<[],f32>
|
||||||
return %0 : !torch.vtensor<[],f32>
|
return %0 : !torch.vtensor<[],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.constant.bool() -> !torch.bool {
|
||||||
|
// CHECK: %[[CST:.*]] = constant true
|
||||||
|
// CHECK: %[[BOOL:.*]] = torch.from_i1 %[[CST]]
|
||||||
|
// CHECK: return %[[BOOL]] : !torch.bool
|
||||||
|
func @torch.constant.bool() -> !torch.bool {
|
||||||
|
%true = torch.constant.bool true
|
||||||
|
return %true : !torch.bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.constant.float() -> !torch.float {
|
||||||
|
// CHECK: %[[CST:.*]] = constant 1.000000e+00 : f64
|
||||||
|
// CHECK: %[[FLOAT:.*]] = torch.from_f64 %[[CST]]
|
||||||
|
// CHECK: return %[[FLOAT]] : !torch.float
|
||||||
|
func @torch.constant.float() -> !torch.float {
|
||||||
|
%float = torch.constant.float 1.000000e+00
|
||||||
|
return %float : !torch.float
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.constant.int() -> !torch.int {
|
||||||
|
// CHECK: %[[CST:.*]] = constant 1 : i64
|
||||||
|
// CHECK: %[[INT:.*]] = torch.from_i64 %[[CST]]
|
||||||
|
// CHECK: return %[[INT]] : !torch.int
|
||||||
|
func @torch.constant.int() -> !torch.int {
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
return %int1 : !torch.int
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue