mirror of https://github.com/llvm/torch-mlir
Add basicpy.numeric_constant op.
* Going through TODOs on the PyTorch side, this is a big cause of them (not being able to have constants for signed/unsigned). * Added complex while in here since we're at the phase where it is better to just have things complete than partially done.pull/128/head
parent
bea0af419d
commit
3937dd14cb
|
@ -17,7 +17,8 @@ using namespace torch_mlir;
|
||||||
static MlirOperation createStandardConstant(MlirLocation loc, MlirType type,
|
static MlirOperation createStandardConstant(MlirLocation loc, MlirType type,
|
||||||
MlirAttribute value) {
|
MlirAttribute value) {
|
||||||
OperationStateHolder s("std.constant", loc);
|
OperationStateHolder s("std.constant", loc);
|
||||||
MlirNamedAttribute valueAttr = mlirNamedAttributeGet(toMlirStringRef("value"), value);
|
MlirNamedAttribute valueAttr =
|
||||||
|
mlirNamedAttributeGet(toMlirStringRef("value"), value);
|
||||||
mlirOperationStateAddResults(s, 1, &type);
|
mlirOperationStateAddResults(s, 1, &type);
|
||||||
mlirOperationStateAddAttributes(s, 1, &valueAttr);
|
mlirOperationStateAddAttributes(s, 1, &valueAttr);
|
||||||
return s.createOperation();
|
return s.createOperation();
|
||||||
|
@ -44,12 +45,15 @@ void KernelCallBuilder::addSchemaAttrs() {
|
||||||
// sigIsVarret
|
// sigIsVarret
|
||||||
// sigIsMutable
|
// sigIsMutable
|
||||||
llvm::SmallVector<MlirNamedAttribute, 8> attrs;
|
llvm::SmallVector<MlirNamedAttribute, 8> attrs;
|
||||||
attrs.push_back(mlirNamedAttributeGet(
|
attrs.push_back(
|
||||||
toMlirStringRef("sigIsMutable"), mlirBoolAttrGet(context, schema.is_mutable())));
|
mlirNamedAttributeGet(toMlirStringRef("sigIsMutable"),
|
||||||
attrs.push_back(mlirNamedAttributeGet(
|
mlirBoolAttrGet(context, schema.is_mutable())));
|
||||||
toMlirStringRef("sigIsVararg"), mlirBoolAttrGet(context, schema.is_vararg())));
|
attrs.push_back(
|
||||||
attrs.push_back(mlirNamedAttributeGet(
|
mlirNamedAttributeGet(toMlirStringRef("sigIsVararg"),
|
||||||
toMlirStringRef("sigIsVarret"), mlirBoolAttrGet(context, schema.is_varret())));
|
mlirBoolAttrGet(context, schema.is_vararg())));
|
||||||
|
attrs.push_back(
|
||||||
|
mlirNamedAttributeGet(toMlirStringRef("sigIsVarret"),
|
||||||
|
mlirBoolAttrGet(context, schema.is_varret())));
|
||||||
|
|
||||||
// Arg types.
|
// Arg types.
|
||||||
llvm::SmallVector<MlirAttribute, 4> args;
|
llvm::SmallVector<MlirAttribute, 4> args;
|
||||||
|
@ -58,7 +62,8 @@ void KernelCallBuilder::addSchemaAttrs() {
|
||||||
args.push_back(mlirStringAttrGet(context, typeStr.size(), typeStr.data()));
|
args.push_back(mlirStringAttrGet(context, typeStr.size(), typeStr.data()));
|
||||||
}
|
}
|
||||||
attrs.push_back(mlirNamedAttributeGet(
|
attrs.push_back(mlirNamedAttributeGet(
|
||||||
toMlirStringRef("sigArgTypes"), mlirArrayAttrGet(context, args.size(), args.data())));
|
toMlirStringRef("sigArgTypes"),
|
||||||
|
mlirArrayAttrGet(context, args.size(), args.data())));
|
||||||
|
|
||||||
// Return types.
|
// Return types.
|
||||||
llvm::SmallVector<MlirAttribute, 4> returns;
|
llvm::SmallVector<MlirAttribute, 4> returns;
|
||||||
|
@ -203,14 +208,17 @@ FuncBuilder::createFunction(FuncBuilder::Inserter &inserter,
|
||||||
// TODO: Create a dedicated API upstream for creating/manipulating func ops.
|
// TODO: Create a dedicated API upstream for creating/manipulating func ops.
|
||||||
// (this is fragile and reveals details that are not guaranteed).
|
// (this is fragile and reveals details that are not guaranteed).
|
||||||
llvm::SmallVector<MlirNamedAttribute, 4> funcAttrs;
|
llvm::SmallVector<MlirNamedAttribute, 4> funcAttrs;
|
||||||
|
funcAttrs.push_back(
|
||||||
|
mlirNamedAttributeGet(toMlirStringRef("type"),
|
||||||
|
mlirTypeAttrGet(mlirFunctionTypeGet(
|
||||||
|
context, inputTypes.size(), inputTypes.data(),
|
||||||
|
/*numResults=*/0, /*results=*/nullptr))));
|
||||||
funcAttrs.push_back(mlirNamedAttributeGet(
|
funcAttrs.push_back(mlirNamedAttributeGet(
|
||||||
toMlirStringRef("type"), mlirTypeAttrGet(mlirFunctionTypeGet(
|
toMlirStringRef("sym_name"),
|
||||||
context, inputTypes.size(), inputTypes.data(),
|
mlirStringAttrGet(context, name.size(), name.data())));
|
||||||
/*numResults=*/0, /*results=*/nullptr))));
|
|
||||||
funcAttrs.push_back(mlirNamedAttributeGet(
|
|
||||||
toMlirStringRef("sym_name"), mlirStringAttrGet(context, name.size(), name.data())));
|
|
||||||
|
|
||||||
MlirOperationState state = mlirOperationStateGet(toMlirStringRef("func"), location);
|
MlirOperationState state =
|
||||||
|
mlirOperationStateGet(toMlirStringRef("func"), location);
|
||||||
mlirOperationStateAddAttributes(&state, funcAttrs.size(), funcAttrs.data());
|
mlirOperationStateAddAttributes(&state, funcAttrs.size(), funcAttrs.data());
|
||||||
{
|
{
|
||||||
// Don't access these once ownership transferred.
|
// Don't access these once ownership transferred.
|
||||||
|
@ -234,7 +242,8 @@ FuncBuilder::createFunction(FuncBuilder::Inserter &inserter,
|
||||||
void FuncBuilder::rewriteFuncReturnTypes(
|
void FuncBuilder::rewriteFuncReturnTypes(
|
||||||
llvm::SmallVectorImpl<MlirType> &resultTypes) {
|
llvm::SmallVectorImpl<MlirType> &resultTypes) {
|
||||||
// Get inputs from current function type.
|
// Get inputs from current function type.
|
||||||
MlirAttribute funcTypeAttr = mlirOperationGetAttributeByName(funcOp, toMlirStringRef("type"));
|
MlirAttribute funcTypeAttr =
|
||||||
|
mlirOperationGetAttributeByName(funcOp, toMlirStringRef("type"));
|
||||||
assert(!mlirAttributeIsNull(funcTypeAttr) &&
|
assert(!mlirAttributeIsNull(funcTypeAttr) &&
|
||||||
"function missing 'type' attribute");
|
"function missing 'type' attribute");
|
||||||
assert(mlirAttributeIsAType(funcTypeAttr) &&
|
assert(mlirAttributeIsAType(funcTypeAttr) &&
|
||||||
|
@ -250,7 +259,8 @@ void FuncBuilder::rewriteFuncReturnTypes(
|
||||||
mlirFunctionTypeGet(context, inputTypes.size(), inputTypes.data(),
|
mlirFunctionTypeGet(context, inputTypes.size(), inputTypes.data(),
|
||||||
resultTypes.size(), resultTypes.data());
|
resultTypes.size(), resultTypes.data());
|
||||||
MlirAttribute newFuncTypeAttr = mlirTypeAttrGet(newFuncType);
|
MlirAttribute newFuncTypeAttr = mlirTypeAttrGet(newFuncType);
|
||||||
mlirOperationSetAttributeByName(funcOp, toMlirStringRef("type"), newFuncTypeAttr);
|
mlirOperationSetAttributeByName(funcOp, toMlirStringRef("type"),
|
||||||
|
newFuncTypeAttr);
|
||||||
(void)newFuncTypeAttr;
|
(void)newFuncTypeAttr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -24,8 +24,7 @@ namespace torch_mlir {
|
||||||
class OperationStateHolder {
|
class OperationStateHolder {
|
||||||
public:
|
public:
|
||||||
OperationStateHolder(const char *name, MlirLocation loc)
|
OperationStateHolder(const char *name, MlirLocation loc)
|
||||||
: state(
|
: state(mlirOperationStateGet(toMlirStringRef(name), loc)) {}
|
||||||
mlirOperationStateGet(toMlirStringRef(name), loc)) {}
|
|
||||||
OperationStateHolder(const OperationStateHolder &) = delete;
|
OperationStateHolder(const OperationStateHolder &) = delete;
|
||||||
OperationStateHolder(OperationStateHolder &&other) = delete;
|
OperationStateHolder(OperationStateHolder &&other) = delete;
|
||||||
~OperationStateHolder() {
|
~OperationStateHolder() {
|
||||||
|
|
|
@ -22,6 +22,7 @@ def Basicpy_Dialect : Dialect {
|
||||||
Core types and ops
|
Core types and ops
|
||||||
}];
|
}];
|
||||||
let cppNamespace = "::mlir::NPCOMP::Basicpy";
|
let cppNamespace = "::mlir::NPCOMP::Basicpy";
|
||||||
|
let hasConstantMaterializer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -30,8 +31,9 @@ def Basicpy_Dialect : Dialect {
|
||||||
|
|
||||||
class Basicpy_Op<string mnemonic, list<OpTrait> traits = []> :
|
class Basicpy_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||||
Op<Basicpy_Dialect, mnemonic, traits> {
|
Op<Basicpy_Dialect, mnemonic, traits> {
|
||||||
let parser = [{ return parse$cppClass(parser, &result); }];
|
let parser = [{ return ::parse$cppClass(parser, &result); }];
|
||||||
let printer = [{ return print$cppClass(p, *this); }];
|
let printer = [{ return ::print(p, *this); }];
|
||||||
|
let verifier = [{ return ::verify(*this); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -138,7 +140,7 @@ def Basicpy_DictType : DialectType<Basicpy_Dialect,
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Type predicates
|
// Type/attribute predicates
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def Basicpy_SingletonType : AnyTypeOf<[
|
def Basicpy_SingletonType : AnyTypeOf<[
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
#include "mlir/IR/FunctionSupport.h"
|
#include "mlir/IR/FunctionSupport.h"
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
#include "mlir/IR/OpImplementation.h"
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
#include "mlir/IR/SymbolTable.h"
|
#include "mlir/IR/SymbolTable.h"
|
||||||
#include "mlir/Interfaces/CallInterfaces.h"
|
#include "mlir/Interfaces/CallInterfaces.h"
|
||||||
|
|
|
@ -13,6 +13,7 @@ include "BasicpyDialect.td"
|
||||||
include "mlir/Interfaces/CallInterfaces.td"
|
include "mlir/Interfaces/CallInterfaces.td"
|
||||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
|
include "mlir/IR/OpAsmInterface.td"
|
||||||
include "mlir/IR/SymbolInterfaces.td"
|
include "mlir/IR/SymbolInterfaces.td"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -93,8 +94,54 @@ def CompareOperationAttr : StrEnumAttr<
|
||||||
// Constant and constructor operations
|
// Constant and constructor operations
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def Basicpy_NumericConstantOp : Basicpy_Op<"numeric_constant", [
|
||||||
|
ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
|
||||||
|
let summary = "A constant from the Python3 numeric type hierarchy";
|
||||||
|
let description = [{
|
||||||
|
Basicpy re-uses core MLIR types to represent the Python3 numeric type
|
||||||
|
hierarchy with the following mappings:
|
||||||
|
|
||||||
|
* Python3 `int` : In python, this type is signed, arbitrary precision but
|
||||||
|
in typical realizations, it maps to an MLIR `IntegerType` of a fixed
|
||||||
|
bit-width (typically si64 if no further information is known). In the
|
||||||
|
future, there may be a real `Basicpy::IntType` that retains the true
|
||||||
|
arbitrary precision nature, but this is deemed an enhancement that
|
||||||
|
does not obviate the need to infer physical, sized types for many
|
||||||
|
real-world cases. As such, the Basicpy numeric type hierarchy will
|
||||||
|
always include physical `IntegerType`, if only to enable progressive
|
||||||
|
lowering and interop with cases where the precise type is known.
|
||||||
|
* Python3 `float` : This is allowed to map to any legal floating point
|
||||||
|
type on the physical machine and is usually represented as a double (f64).
|
||||||
|
In MLIR, any `FloatType` is allowed, which facilitates progressive
|
||||||
|
lowering and interop with cases where a more precise type is known.
|
||||||
|
* Python3 `complex` : Maps to an MLIR `ComplexType` with a `FloatType`
|
||||||
|
elementType (note: in Python, complex numbers are always defined with
|
||||||
|
floating point components).
|
||||||
|
* `bool` : See `bool_constant` for a constant (i1) -> !basicpy.BoolType
|
||||||
|
constant. This constant op is not used for representing such bool
|
||||||
|
values, even though from the Python perspective, bool is part of the
|
||||||
|
numeric hierarchy (the distinction is really only necessary during
|
||||||
|
promotion).
|
||||||
|
|
||||||
|
### Integer Signedness
|
||||||
|
|
||||||
|
All `int` values in Python are signed. However, there exist special cases
|
||||||
|
where libraries (i.e. struct packing and numpy arrays) interoperate with
|
||||||
|
unsigned values. As such, when mapping to MLIR, Python integer types
|
||||||
|
are represented as either signed or unsigned `IntegerType` types and can
|
||||||
|
be lowered to signless integers as appropriate (typically during realization
|
||||||
|
of arithmetic expressions where the choice is meaningful). Since it is not
|
||||||
|
known at the outset when in lowering this information is safe to discard
|
||||||
|
this `numeric_constant` op accepts any signedness.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins AnyAttr:$value);
|
||||||
|
let results = (outs AnyType);
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def Basicpy_BoolConstantOp : Basicpy_Op<"bool_constant", [
|
def Basicpy_BoolConstantOp : Basicpy_Op<"bool_constant", [
|
||||||
ConstantLike, NoSideEffect]> {
|
ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
|
||||||
let summary = "A boolean constant";
|
let summary = "A boolean constant";
|
||||||
let description = [{
|
let description = [{
|
||||||
A constant of type !basicpy.BoolType that can take either an i1 value
|
A constant of type !basicpy.BoolType that can take either an i1 value
|
||||||
|
@ -173,7 +220,7 @@ def Basicpy_BuildTupleOp : Basicpy_Op<"build_tuple", [NoSideEffect]> {
|
||||||
}
|
}
|
||||||
|
|
||||||
def Basicpy_BytesConstantOp : Basicpy_Op<"bytes_constant", [
|
def Basicpy_BytesConstantOp : Basicpy_Op<"bytes_constant", [
|
||||||
ConstantLike, NoSideEffect]> {
|
ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
|
||||||
let summary = "Constant bytes value";
|
let summary = "Constant bytes value";
|
||||||
let description = [{
|
let description = [{
|
||||||
A bytes value of BytesType. The value is represented by a StringAttr.
|
A bytes value of BytesType. The value is represented by a StringAttr.
|
||||||
|
@ -204,7 +251,7 @@ def Basicpy_SingletonOp : Basicpy_Op<"singleton", [
|
||||||
}
|
}
|
||||||
|
|
||||||
def Basicpy_StrConstantOp : Basicpy_Op<"str_constant", [
|
def Basicpy_StrConstantOp : Basicpy_Op<"str_constant", [
|
||||||
ConstantLike, NoSideEffect]> {
|
ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
|
||||||
let summary = "Constant string value";
|
let summary = "Constant string value";
|
||||||
let description = [{
|
let description = [{
|
||||||
A string value of StrType. The value is represented by a StringAttr
|
A string value of StrType. The value is represented by a StringAttr
|
||||||
|
@ -224,7 +271,7 @@ def Basicpy_StrConstantOp : Basicpy_Op<"str_constant", [
|
||||||
// Casting and coercion operations
|
// Casting and coercion operations
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def Basicpy_AsPredicateValueOp : Basicpy_Op<"as_predicate_value",
|
def Basicpy_AsI1Op : Basicpy_Op<"as_i1",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect]> {
|
||||||
let summary = "Evaluates an input to an i1 predicate value";
|
let summary = "Evaluates an input to an i1 predicate value";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -355,7 +402,6 @@ def Basicpy_FuncTemplateCallOp : Basicpy_Op<"func_template_call", []> {
|
||||||
StrArrayAttr:$arg_names);
|
StrArrayAttr:$arg_names);
|
||||||
let results = (outs AnyType:$result);
|
let results = (outs AnyType:$result);
|
||||||
let assemblyFormat = "$callee `(` $args `)` `kw` $arg_names attr-dict `:` functional-type($args, results)";
|
let assemblyFormat = "$callee `(` $args `)` `kw` $arg_names attr-dict `:` functional-type($args, results)";
|
||||||
let verifier = [{ return verifyBasicpyOp(*this); }];
|
|
||||||
let skipDefaultBuilders = 1;
|
let skipDefaultBuilders = 1;
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilderDAG<(ins)>,
|
OpBuilderDAG<(ins)>,
|
||||||
|
@ -427,8 +473,6 @@ def Basicpy_FuncTemplateOp : Basicpy_Op<"func_template", [
|
||||||
let arguments = (ins);
|
let arguments = (ins);
|
||||||
let regions = (region SizedRegion<1>:$body);
|
let regions = (region SizedRegion<1>:$body);
|
||||||
|
|
||||||
let verifier = [{ return verifyBasicpyOp(*this); }];
|
|
||||||
|
|
||||||
let skipDefaultBuilders = 1;
|
let skipDefaultBuilders = 1;
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilderDAG<(ins)>,
|
OpBuilderDAG<(ins)>,
|
||||||
|
|
|
@ -215,11 +215,11 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Converts the as_predicate_value op for numeric types.
|
// Converts the as_i1 op for numeric types.
|
||||||
class NumericToPredicateValue : public OpRewritePattern<Basicpy::AsPredicateValueOp> {
|
class NumericToI1 : public OpRewritePattern<Basicpy::AsI1Op> {
|
||||||
public:
|
public:
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
LogicalResult matchAndRewrite(Basicpy::AsPredicateValueOp op,
|
LogicalResult matchAndRewrite(Basicpy::AsI1Op op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
auto operandType = op.operand().getType();
|
auto operandType = op.operand().getType();
|
||||||
|
@ -245,5 +245,5 @@ void mlir::NPCOMP::populateBasicpyToStdPrimitiveOpPatterns(
|
||||||
MLIRContext *context, OwningRewritePatternList &patterns) {
|
MLIRContext *context, OwningRewritePatternList &patterns) {
|
||||||
patterns.insert<NumericBinaryExpr>(context);
|
patterns.insert<NumericBinaryExpr>(context);
|
||||||
patterns.insert<NumericCompare>(context);
|
patterns.insert<NumericCompare>(context);
|
||||||
patterns.insert<NumericToPredicateValue>(context);
|
patterns.insert<NumericToI1>(context);
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,6 +27,37 @@ void BasicpyDialect::initialize() {
|
||||||
allowUnknownOperations();
|
allowUnknownOperations();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Operation *BasicpyDialect::materializeConstant(OpBuilder &builder,
|
||||||
|
Attribute value, Type type,
|
||||||
|
Location loc) {
|
||||||
|
// NumericConstantOp.
|
||||||
|
// Supports IntegerType (any signedness), FloatType and ComplexType.
|
||||||
|
if (type.isa<IntegerType>() || type.isa<FloatType>() ||
|
||||||
|
type.isa<ComplexType>())
|
||||||
|
return builder.create<NumericConstantOp>(loc, type, value);
|
||||||
|
|
||||||
|
// Bool (i1 -> !basicpy.BoolType).
|
||||||
|
if (type.isa<Basicpy::BoolType>()) {
|
||||||
|
auto i1Value = value.dyn_cast<IntegerAttr>();
|
||||||
|
if (i1Value && i1Value.getType().getIntOrFloatBitWidth() == 1)
|
||||||
|
return builder.create<BoolConstantOp>(loc, type, i1Value);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bytes.
|
||||||
|
if (type.isa<Basicpy::BytesType>()) {
|
||||||
|
if (auto strValue = value.dyn_cast<StringAttr>())
|
||||||
|
return builder.create<BytesConstantOp>(loc, type, strValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Str.
|
||||||
|
if (type.isa<Basicpy::StrType>()) {
|
||||||
|
if (auto strValue = value.dyn_cast<StringAttr>())
|
||||||
|
return builder.create<StrConstantOp>(loc, type, strValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
Type BasicpyDialect::parseType(DialectAsmParser &parser) const {
|
Type BasicpyDialect::parseType(DialectAsmParser &parser) const {
|
||||||
StringRef keyword;
|
StringRef keyword;
|
||||||
if (parser.parseKeyword(&keyword))
|
if (parser.parseKeyword(&keyword))
|
||||||
|
|
|
@ -13,12 +13,13 @@
|
||||||
#include "mlir/IR/OpImplementation.h"
|
#include "mlir/IR/OpImplementation.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||||
|
|
||||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOpsEnums.cpp.inc"
|
#include "npcomp/Dialect/Basicpy/IR/BasicpyOpsEnums.cpp.inc"
|
||||||
|
|
||||||
namespace mlir {
|
using namespace mlir;
|
||||||
namespace NPCOMP {
|
using namespace mlir::NPCOMP::Basicpy;
|
||||||
namespace Basicpy {
|
|
||||||
|
// Fallback verifier for ops that don't have a dedicated one.
|
||||||
|
template <typename T> static LogicalResult verify(T op) { return success(); }
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// BoolConstantOp
|
// BoolConstantOp
|
||||||
|
@ -28,6 +29,11 @@ OpFoldResult BoolConstantOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return valueAttr();
|
return valueAttr();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void BoolConstantOp::getAsmResultNames(
|
||||||
|
function_ref<void(Value, StringRef)> setNameFn) {
|
||||||
|
setNameFn(getResult(), "bool");
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// BytesConstantOp
|
// BytesConstantOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -36,6 +42,110 @@ OpFoldResult BytesConstantOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return valueAttr();
|
return valueAttr();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void BytesConstantOp::getAsmResultNames(
|
||||||
|
function_ref<void(Value, StringRef)> setNameFn) {
|
||||||
|
setNameFn(getResult(), "bytes");
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// NumericConstantOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static ParseResult parseNumericConstantOp(OpAsmParser &parser,
|
||||||
|
OperationState *result) {
|
||||||
|
Attribute valueAttr;
|
||||||
|
if (parser.parseOptionalAttrDict(result->attributes) ||
|
||||||
|
parser.parseAttribute(valueAttr, "value", result->attributes))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// If not an Integer or Float attr (which carry the type in the attr),
|
||||||
|
// expect a trailing type.
|
||||||
|
Type type;
|
||||||
|
if (valueAttr.isa<IntegerAttr>() || valueAttr.isa<FloatAttr>())
|
||||||
|
type = valueAttr.getType();
|
||||||
|
else if (parser.parseColonType(type))
|
||||||
|
return failure();
|
||||||
|
return parser.addTypeToList(type, result->types);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void print(OpAsmPrinter &p, NumericConstantOp op) {
|
||||||
|
p << "basicpy.numeric_constant ";
|
||||||
|
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});
|
||||||
|
|
||||||
|
if (op.getAttrs().size() > 1)
|
||||||
|
p << ' ';
|
||||||
|
p << op.value();
|
||||||
|
|
||||||
|
// If not an Integer or Float attr, expect a trailing type.
|
||||||
|
if (!op.value().isa<IntegerAttr>() && !op.value().isa<FloatAttr>())
|
||||||
|
p << " : " << op.getType();
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verify(NumericConstantOp &op) {
|
||||||
|
auto value = op.value();
|
||||||
|
if (!value)
|
||||||
|
return op.emitOpError("requires a 'value' attribute");
|
||||||
|
auto type = op.getType();
|
||||||
|
|
||||||
|
if (type.isa<FloatType>()) {
|
||||||
|
if (!value.isa<FloatAttr>())
|
||||||
|
return op.emitOpError("requires 'value' to be a floating point constant");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto intType = type.dyn_cast<IntegerType>()) {
|
||||||
|
if (!value.isa<IntegerAttr>())
|
||||||
|
return op.emitOpError("requires 'value' to be an integer constant");
|
||||||
|
if (intType.getWidth() == 1)
|
||||||
|
return op.emitOpError("cannot have an i1 type");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type.isa<ComplexType>()) {
|
||||||
|
if (auto complexComps = value.dyn_cast<ArrayAttr>()) {
|
||||||
|
if (complexComps.size() == 2) {
|
||||||
|
auto realValue = complexComps[0].dyn_cast<FloatAttr>();
|
||||||
|
auto imagValue = complexComps[1].dyn_cast<FloatAttr>();
|
||||||
|
if (realValue && imagValue &&
|
||||||
|
realValue.getType() == imagValue.getType())
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return op.emitOpError("requires 'value' to be a two element array of "
|
||||||
|
"floating point complex number components");
|
||||||
|
}
|
||||||
|
|
||||||
|
return op.emitOpError("unsupported basicpy.numeric_constant type");
|
||||||
|
}
|
||||||
|
|
||||||
|
OpFoldResult NumericConstantOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
assert(operands.empty() && "numeric_constant has no operands");
|
||||||
|
return value();
|
||||||
|
}
|
||||||
|
|
||||||
|
void NumericConstantOp::getAsmResultNames(
|
||||||
|
function_ref<void(Value, StringRef)> setNameFn) {
|
||||||
|
Type type = getType();
|
||||||
|
if (auto intCst = value().dyn_cast<IntegerAttr>()) {
|
||||||
|
IntegerType intTy = type.dyn_cast<IntegerType>();
|
||||||
|
APInt intValue = intCst.getValue();
|
||||||
|
|
||||||
|
// Otherwise, build a complex name with the value and type.
|
||||||
|
SmallString<32> specialNameBuffer;
|
||||||
|
llvm::raw_svector_ostream specialName(specialNameBuffer);
|
||||||
|
specialName << "num";
|
||||||
|
if (intTy.isSigned())
|
||||||
|
specialName << intValue.getSExtValue();
|
||||||
|
else
|
||||||
|
specialName << intValue.getZExtValue();
|
||||||
|
if (intTy)
|
||||||
|
specialName << '_' << type;
|
||||||
|
setNameFn(getResult(), specialName.str());
|
||||||
|
} else {
|
||||||
|
setNameFn(getResult(), "num");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ExecOp
|
// ExecOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -54,7 +164,7 @@ static ParseResult parseExecOp(OpAsmParser &parser, OperationState *result) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void printExecOp(OpAsmPrinter &p, ExecOp op) {
|
static void print(OpAsmPrinter &p, ExecOp op) {
|
||||||
p << op.getOperationName();
|
p << op.getOperationName();
|
||||||
p.printOptionalAttrDictWithKeyword(op.getAttrs());
|
p.printOptionalAttrDictWithKeyword(op.getAttrs());
|
||||||
p.printRegion(op.body());
|
p.printRegion(op.body());
|
||||||
|
@ -64,7 +174,7 @@ static void printExecOp(OpAsmPrinter &p, ExecOp op) {
|
||||||
// FuncTemplateCallOp
|
// FuncTemplateCallOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static LogicalResult verifyBasicpyOp(FuncTemplateCallOp op) {
|
static LogicalResult verify(FuncTemplateCallOp op) {
|
||||||
auto argNames = op.arg_names();
|
auto argNames = op.arg_names();
|
||||||
if (argNames.size() > op.args().size()) {
|
if (argNames.size() > op.args().size()) {
|
||||||
return op.emitOpError() << "expected <= kw arg names vs args";
|
return op.emitOpError() << "expected <= kw arg names vs args";
|
||||||
|
@ -108,7 +218,7 @@ static ParseResult parseFuncTemplateOp(OpAsmParser &parser,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void printFuncTemplateOp(OpAsmPrinter &p, FuncTemplateOp op) {
|
static void print(OpAsmPrinter &p, FuncTemplateOp op) {
|
||||||
p << op.getOperationName() << " ";
|
p << op.getOperationName() << " ";
|
||||||
p.printSymbolName(op.getName());
|
p.printSymbolName(op.getName());
|
||||||
p.printOptionalAttrDictWithKeyword(op.getAttrs(),
|
p.printOptionalAttrDictWithKeyword(op.getAttrs(),
|
||||||
|
@ -116,7 +226,7 @@ static void printFuncTemplateOp(OpAsmPrinter &p, FuncTemplateOp op) {
|
||||||
p.printRegion(op.body());
|
p.printRegion(op.body());
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult verifyBasicpyOp(FuncTemplateOp op) {
|
static LogicalResult verify(FuncTemplateOp op) {
|
||||||
Block *body = op.getBody();
|
Block *body = op.getBody();
|
||||||
for (auto &childOp : body->getOperations()) {
|
for (auto &childOp : body->getOperations()) {
|
||||||
if (!llvm::isa<FuncOp>(childOp) &&
|
if (!llvm::isa<FuncOp>(childOp) &&
|
||||||
|
@ -151,7 +261,7 @@ static ParseResult parseSlotObjectMakeOp(OpAsmParser &parser,
|
||||||
parser.getNameLoc(), result->operands);
|
parser.getNameLoc(), result->operands);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void printSlotObjectMakeOp(OpAsmPrinter &p, SlotObjectMakeOp op) {
|
static void print(OpAsmPrinter &p, SlotObjectMakeOp op) {
|
||||||
// If the argument types do not match the result type slots, then
|
// If the argument types do not match the result type slots, then
|
||||||
// print the generic form.
|
// print the generic form.
|
||||||
auto canCustomPrint = ([&]() -> bool {
|
auto canCustomPrint = ([&]() -> bool {
|
||||||
|
@ -218,7 +328,7 @@ static ParseResult parseSlotObjectGetOp(OpAsmParser &parser,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void printSlotObjectGetOp(OpAsmPrinter &p, SlotObjectGetOp op) {
|
static void print(OpAsmPrinter &p, SlotObjectGetOp op) {
|
||||||
// If the argument types do not match the result type slots, then
|
// If the argument types do not match the result type slots, then
|
||||||
// print the generic form.
|
// print the generic form.
|
||||||
auto canCustomPrint = ([&]() -> bool {
|
auto canCustomPrint = ([&]() -> bool {
|
||||||
|
@ -262,6 +372,11 @@ OpFoldResult StrConstantOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return valueAttr();
|
return valueAttr();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void StrConstantOp::getAsmResultNames(
|
||||||
|
function_ref<void(Value, StringRef)> setNameFn) {
|
||||||
|
setNameFn(getResult(), "str");
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// UnknownCastOp
|
// UnknownCastOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -287,9 +402,5 @@ void UnknownCastOp::getCanonicalizationPatterns(
|
||||||
patterns.insert<ElideIdentityUnknownCast>(context);
|
patterns.insert<ElideIdentityUnknownCast>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace Basicpy
|
|
||||||
} // namespace NPCOMP
|
|
||||||
} // namespace mlir
|
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.cpp.inc"
|
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.cpp.inc"
|
||||||
|
|
|
@ -342,7 +342,7 @@ public:
|
||||||
op);
|
op);
|
||||||
return WalkResult::advance();
|
return WalkResult::advance();
|
||||||
}
|
}
|
||||||
if (auto op = dyn_cast<AsPredicateValueOp>(childOp)) {
|
if (auto op = dyn_cast<AsI1Op>(childOp)) {
|
||||||
// Note that the result is always i1 and not subject to type
|
// Note that the result is always i1 and not subject to type
|
||||||
// inference.
|
// inference.
|
||||||
equations.getTypeNode(op.operand());
|
equations.getTypeNode(op.operand());
|
||||||
|
|
|
@ -140,7 +140,7 @@ public:
|
||||||
// addSubtypeConstraint(op.false_value(), op.true_value(), op);
|
// addSubtypeConstraint(op.false_value(), op.true_value(), op);
|
||||||
return WalkResult::advance();
|
return WalkResult::advance();
|
||||||
}
|
}
|
||||||
if (auto op = dyn_cast<AsPredicateValueOp>(childOp)) {
|
if (auto op = dyn_cast<AsI1Op>(childOp)) {
|
||||||
// Note that the result is always i1 and not subject to type
|
// Note that the result is always i1 and not subject to type
|
||||||
// inference.
|
// inference.
|
||||||
resolveValueType(op.operand());
|
resolveValueType(op.operand());
|
||||||
|
|
|
@ -255,7 +255,7 @@ class ExpressionImporter(BaseNodeVisitor):
|
||||||
next_value = self.sub_evaluate(next_node)
|
next_value = self.sub_evaluate(next_node)
|
||||||
if not next_nodes:
|
if not next_nodes:
|
||||||
return next_value
|
return next_value
|
||||||
condition_value = ir_h.basicpy_as_predicate_value_op(next_value).result
|
condition_value = ir_h.basicpy_as_i1_op(next_value).result
|
||||||
if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_UnknownType],
|
if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_UnknownType],
|
||||||
condition_value, True)
|
condition_value, True)
|
||||||
orig_ip = ir_h.builder.insertion_point
|
orig_ip = ir_h.builder.insertion_point
|
||||||
|
@ -347,8 +347,7 @@ class ExpressionImporter(BaseNodeVisitor):
|
||||||
|
|
||||||
def visit_IfExp(self, ast_node):
|
def visit_IfExp(self, ast_node):
|
||||||
ir_h = self.fctx.ir_h
|
ir_h = self.fctx.ir_h
|
||||||
test_result = ir_h.basicpy_as_predicate_value_op(self.sub_evaluate(
|
test_result = ir_h.basicpy_as_i1_op(self.sub_evaluate(ast_node.test)).result
|
||||||
ast_node.test)).result
|
|
||||||
if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_UnknownType],
|
if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_UnknownType],
|
||||||
test_result, True)
|
test_result, True)
|
||||||
|
|
||||||
|
@ -386,7 +385,7 @@ class ExpressionImporter(BaseNodeVisitor):
|
||||||
operand_value = self.sub_evaluate(ast_node.operand)
|
operand_value = self.sub_evaluate(ast_node.operand)
|
||||||
if isinstance(op, ast.Not):
|
if isinstance(op, ast.Not):
|
||||||
# Special handling for logical-not.
|
# Special handling for logical-not.
|
||||||
condition_value = ir_h.basicpy_as_predicate_value_op(operand_value).result
|
condition_value = ir_h.basicpy_as_i1_op(operand_value).result
|
||||||
true_value = ir_h.basicpy_bool_constant_op(True).result
|
true_value = ir_h.basicpy_bool_constant_op(True).result
|
||||||
false_value = ir_h.basicpy_bool_constant_op(False).result
|
false_value = ir_h.basicpy_bool_constant_op(False).result
|
||||||
self.value = ir_h.select_op(condition_value, false_value,
|
self.value = ir_h.select_op(condition_value, false_value,
|
||||||
|
|
|
@ -90,8 +90,8 @@ class DialectHelper(_BaseDialectHelper):
|
||||||
attrs = c.dictionary_attr({"value": c.string_attr(value.encode("utf-8"))})
|
attrs = c.dictionary_attr({"value": c.string_attr(value.encode("utf-8"))})
|
||||||
return self.op("basicpy.str_constant", [self.basicpy_StrType], [], attrs)
|
return self.op("basicpy.str_constant", [self.basicpy_StrType], [], attrs)
|
||||||
|
|
||||||
def basicpy_as_predicate_value_op(self, value):
|
def basicpy_as_i1_op(self, value):
|
||||||
return self.op("basicpy.as_predicate_value", [self.i1_type], [value])
|
return self.op("basicpy.as_i1", [self.i1_type], [value])
|
||||||
|
|
||||||
def basicpy_unknown_cast_op(self, result_type, operand):
|
def basicpy_unknown_cast_op(self, result_type, operand):
|
||||||
return self.op("basicpy.unknown_cast", [result_type], [operand])
|
return self.op("basicpy.unknown_cast", [result_type], [operand])
|
||||||
|
|
|
@ -7,9 +7,66 @@ func @unknown_cast_elide(%arg0 : i32) -> i32 {
|
||||||
return %0 : i32
|
return %0 : i32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
// CHECK-LABEL: func @unknown_cast_preserve
|
// CHECK-LABEL: func @unknown_cast_preserve
|
||||||
func @unknown_cast_preserve(%arg0 : i32) -> !basicpy.UnknownType {
|
func @unknown_cast_preserve(%arg0 : i32) -> !basicpy.UnknownType {
|
||||||
// CHECK: basicpy.unknown_cast
|
// CHECK: basicpy.unknown_cast
|
||||||
%0 = basicpy.unknown_cast %arg0 : i32 -> !basicpy.UnknownType
|
%0 = basicpy.unknown_cast %arg0 : i32 -> !basicpy.UnknownType
|
||||||
return %0 : !basicpy.UnknownType
|
return %0 : !basicpy.UnknownType
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: @numeric_constant_si32
|
||||||
|
func @numeric_constant_si32() -> si32 {
|
||||||
|
// CHECK: %num-1_si32 = basicpy.numeric_constant -1 : si32
|
||||||
|
%0 = basicpy.numeric_constant -1 : si32
|
||||||
|
return %0 : si32
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: @numeric_constant_ui32
|
||||||
|
func @numeric_constant_ui32() -> ui32 {
|
||||||
|
// CHECK: %num1_ui32 = basicpy.numeric_constant 1 : ui32
|
||||||
|
%0 = basicpy.numeric_constant 1 : ui32
|
||||||
|
return %0 : ui32
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: @numeric_constant_f32
|
||||||
|
func @numeric_constant_f32() -> f32 {
|
||||||
|
// CHECK: %num = basicpy.numeric_constant 2.000000e+00 : f32
|
||||||
|
%0 = basicpy.numeric_constant 2.0 : f32
|
||||||
|
return %0 : f32
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: @numeric_constant_complex_f32
|
||||||
|
func @numeric_constant_complex_f32() -> complex<f32> {
|
||||||
|
// CHECK: %num = basicpy.numeric_constant [2.000000e+00 : f32, 3.000000e+00 : f32] : complex<f32>
|
||||||
|
%0 = basicpy.numeric_constant [2.0 : f32, 3.0 : f32] : complex<f32>
|
||||||
|
return %0 : complex<f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: @bool_constant
|
||||||
|
func @bool_constant() -> !basicpy.BoolType {
|
||||||
|
// CHECK: %bool = basicpy.bool_constant true
|
||||||
|
%0 = basicpy.bool_constant true
|
||||||
|
return %0 : !basicpy.BoolType
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: @bytes_constant
|
||||||
|
func @bytes_constant() -> !basicpy.BytesType {
|
||||||
|
// CHECK: %bytes = basicpy.bytes_constant "foobar"
|
||||||
|
%0 = basicpy.bytes_constant "foobar"
|
||||||
|
return %0 : !basicpy.BytesType
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: @str_constant
|
||||||
|
func @str_constant() -> !basicpy.StrType {
|
||||||
|
// CHECK: %str = basicpy.str_constant "foobar"
|
||||||
|
%0 = basicpy.str_constant "foobar"
|
||||||
|
return %0 : !basicpy.StrType
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
// RUN: npcomp-opt -split-input-file -verify-diagnostics %s
|
||||||
|
|
||||||
|
func @numeric_constant_string_attr() {
|
||||||
|
// expected-error @+1 {{op requires 'value' to be an integer constant}}
|
||||||
|
%0 = "basicpy.numeric_constant"() {value="somestring" : i32} : () -> (i32)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
func @numeric_constant_bool() {
|
||||||
|
// expected-error @+1 {{cannot have an i1 type}}
|
||||||
|
%0 = "basicpy.numeric_constant"() {value = true} : () -> (i1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
func @numeric_constant_mismatch_int() {
|
||||||
|
// expected-error @+1 {{op requires 'value' to be a floating point constant}}
|
||||||
|
%0 = "basicpy.numeric_constant"() {value = 1 : i32} : () -> (f64)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
func @numeric_constant_mismatch_float() {
|
||||||
|
// expected-error @+1 {{op requires 'value' to be an integer constant}}
|
||||||
|
%0 = "basicpy.numeric_constant"() {value = 1.0 : f32} : () -> (i32)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
func @numeric_constant_complex_wrong_arity() {
|
||||||
|
// expected-error @+1 {{op requires 'value' to be a two element array of floating point complex number components}}
|
||||||
|
%3 = basicpy.numeric_constant [2.0 : f32] : complex<f32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
func @numeric_constant_complex_mismatch_type_real() {
|
||||||
|
// expected-error @+1 {{op requires 'value' to be a two element array of floating point complex number components}}
|
||||||
|
%3 = basicpy.numeric_constant [2.0 : f64, 3.0 : f32] : complex<f32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
func @numeric_constant_complex_mismatch_type_imag() {
|
||||||
|
// expected-error @+1 {{op requires 'value' to be a two element array of floating point complex number components}}
|
||||||
|
%3 = basicpy.numeric_constant [2.0 : f32, 3.0 : f16] : complex<f32>
|
||||||
|
return
|
||||||
|
}
|
|
@ -24,4 +24,16 @@ func @build_tuple_generic(%arg0 : si32, %arg1 : si32) -> !basicpy.TupleType {
|
||||||
return %0 : !basicpy.TupleType
|
return %0 : !basicpy.TupleType
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: @numeric_constant
|
||||||
|
func @numeric_constant() {
|
||||||
|
// CHECK: %num-1_si32 = basicpy.numeric_constant -1 : si32
|
||||||
|
%0 = basicpy.numeric_constant -1 : si32
|
||||||
|
// CHECK: %num1_ui32 = basicpy.numeric_constant 1 : ui32
|
||||||
|
%1 = basicpy.numeric_constant 1 : ui32
|
||||||
|
// CHECK: %num = basicpy.numeric_constant 2.000000e+00 : f32
|
||||||
|
%2 = basicpy.numeric_constant 2.0 : f32
|
||||||
|
// CHECK: %num_0 = basicpy.numeric_constant [2.000000e+00 : f32, 3.000000e+00 : f32] : complex<f32>
|
||||||
|
%3 = basicpy.numeric_constant [2.0 : f32, 3.0 : f32] : complex<f32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -14,9 +14,9 @@ def logical_and():
|
||||||
x = 1
|
x = 1
|
||||||
y = 0
|
y = 0
|
||||||
z = 2
|
z = 2
|
||||||
# CHECK: %[[XBOOL:.*]] = basicpy.as_predicate_value %[[X]]
|
# CHECK: %[[XBOOL:.*]] = basicpy.as_i1 %[[X]]
|
||||||
# CHECK: %[[IF0:.*]] = scf.if %[[XBOOL]] -> (!basicpy.UnknownType) {
|
# CHECK: %[[IF0:.*]] = scf.if %[[XBOOL]] -> (!basicpy.UnknownType) {
|
||||||
# CHECK: %[[YBOOL:.*]] = basicpy.as_predicate_value %[[Y]]
|
# CHECK: %[[YBOOL:.*]] = basicpy.as_i1 %[[Y]]
|
||||||
# CHECK: %[[IF1:.*]] = scf.if %[[YBOOL]] -> (!basicpy.UnknownType) {
|
# CHECK: %[[IF1:.*]] = scf.if %[[YBOOL]] -> (!basicpy.UnknownType) {
|
||||||
# CHECK: %[[ZCAST:.*]] = basicpy.unknown_cast %[[Z]]
|
# CHECK: %[[ZCAST:.*]] = basicpy.unknown_cast %[[Z]]
|
||||||
# CHECK: scf.yield %[[ZCAST]]
|
# CHECK: scf.yield %[[ZCAST]]
|
||||||
|
@ -39,12 +39,12 @@ def logical_or():
|
||||||
# CHECK: %[[X:.*]] = constant 0
|
# CHECK: %[[X:.*]] = constant 0
|
||||||
# CHECK: %[[Y:.*]] = constant 1
|
# CHECK: %[[Y:.*]] = constant 1
|
||||||
# CHECK: %[[Z:.*]] = constant 2
|
# CHECK: %[[Z:.*]] = constant 2
|
||||||
# CHECK: %[[XBOOL:.*]] = basicpy.as_predicate_value %[[X]]
|
# CHECK: %[[XBOOL:.*]] = basicpy.as_i1 %[[X]]
|
||||||
# CHECK: %[[IF0:.*]] = scf.if %[[XBOOL]] -> (!basicpy.UnknownType) {
|
# CHECK: %[[IF0:.*]] = scf.if %[[XBOOL]] -> (!basicpy.UnknownType) {
|
||||||
# CHECK: %[[XCAST:.*]] = basicpy.unknown_cast %[[X]]
|
# CHECK: %[[XCAST:.*]] = basicpy.unknown_cast %[[X]]
|
||||||
# CHECK: scf.yield %[[XCAST]]
|
# CHECK: scf.yield %[[XCAST]]
|
||||||
# CHECK: } else {
|
# CHECK: } else {
|
||||||
# CHECK: %[[YBOOL:.*]] = basicpy.as_predicate_value %[[Y]]
|
# CHECK: %[[YBOOL:.*]] = basicpy.as_i1 %[[Y]]
|
||||||
# CHECK: %[[IF1:.*]] = scf.if %[[YBOOL]] -> (!basicpy.UnknownType) {
|
# CHECK: %[[IF1:.*]] = scf.if %[[YBOOL]] -> (!basicpy.UnknownType) {
|
||||||
# CHECK: %[[YCAST:.*]] = basicpy.unknown_cast %[[Y]]
|
# CHECK: %[[YCAST:.*]] = basicpy.unknown_cast %[[Y]]
|
||||||
# CHECK: scf.yield %[[YCAST]]
|
# CHECK: scf.yield %[[YCAST]]
|
||||||
|
@ -68,7 +68,7 @@ def logical_not():
|
||||||
x = 1
|
x = 1
|
||||||
# CHECK-DAG: %[[TRUE:.*]] = basicpy.bool_constant true
|
# CHECK-DAG: %[[TRUE:.*]] = basicpy.bool_constant true
|
||||||
# CHECK-DAG: %[[FALSE:.*]] = basicpy.bool_constant false
|
# CHECK-DAG: %[[FALSE:.*]] = basicpy.bool_constant false
|
||||||
# CHECK-DAG: %[[CONDITION:.*]] = basicpy.as_predicate_value %[[X]]
|
# CHECK-DAG: %[[CONDITION:.*]] = basicpy.as_i1 %[[X]]
|
||||||
# CHECK-DAG: %{{.*}} = select %[[CONDITION]], %[[FALSE]], %[[TRUE]] : !basicpy.BoolType
|
# CHECK-DAG: %{{.*}} = select %[[CONDITION]], %[[FALSE]], %[[TRUE]] : !basicpy.BoolType
|
||||||
return not x
|
return not x
|
||||||
|
|
||||||
|
@ -78,7 +78,7 @@ def logical_not():
|
||||||
def conditional():
|
def conditional():
|
||||||
# CHECK: %[[X:.*]] = constant 1
|
# CHECK: %[[X:.*]] = constant 1
|
||||||
x = 1
|
x = 1
|
||||||
# CHECK: %[[CONDITION:.*]] = basicpy.as_predicate_value %[[X]]
|
# CHECK: %[[CONDITION:.*]] = basicpy.as_i1 %[[X]]
|
||||||
# CHECK: %[[IF0:.*]] = scf.if %[[CONDITION]] -> (!basicpy.UnknownType) {
|
# CHECK: %[[IF0:.*]] = scf.if %[[CONDITION]] -> (!basicpy.UnknownType) {
|
||||||
# CHECK: %[[TWO:.*]] = constant 2 : i64
|
# CHECK: %[[TWO:.*]] = constant 2 : i64
|
||||||
# CHECK: %[[TWO_CAST:.*]] = basicpy.unknown_cast %[[TWO]]
|
# CHECK: %[[TWO_CAST:.*]] = basicpy.unknown_cast %[[TWO]]
|
||||||
|
|
Loading…
Reference in New Issue