mirror of https://github.com/llvm/torch-mlir
Introduce a Target class and use it to define generic 32 and 64bit variants.
parent
2bc7a77f98
commit
c3d4436397
|
@ -0,0 +1,30 @@
|
|||
# RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail
|
||||
|
||||
# Subset of constant tests which verify against a GenericTarget32.
|
||||
|
||||
from npcomp.compiler.frontend import *
|
||||
from npcomp.compiler.target import *
|
||||
|
||||
|
||||
def import_global(f):
|
||||
fe = ImportFrontend(target_factory=GenericTarget32)
|
||||
fe.import_global_function(f)
|
||||
print("// -----")
|
||||
print(fe.ir_module.to_asm())
|
||||
return f
|
||||
|
||||
|
||||
# CHECK-LABEL: func @integer_constants
|
||||
@import_global
|
||||
def integer_constants():
|
||||
# CHECK: %[[A:.*]] = constant 100 : i32
|
||||
a = 100
|
||||
return a
|
||||
|
||||
|
||||
# CHECK-LABEL: func @float_constants
|
||||
@import_global
|
||||
def float_constants():
|
||||
# CHECK: %[[A:.*]] = constant 2.200000e+00 : f32
|
||||
a = 2.2
|
||||
return a
|
|
@ -14,6 +14,7 @@ from npcomp.dialect import Numpy
|
|||
|
||||
from . import logging
|
||||
from .importer import *
|
||||
from .target import *
|
||||
|
||||
__all__ = [
|
||||
"ImportFrontend",
|
||||
|
@ -32,12 +33,21 @@ class AllDialectHelper(Numpy.DialectHelper, ScfDialectHelper):
|
|||
|
||||
class ImportFrontend:
|
||||
"""Frontend for importing various entities into a Module."""
|
||||
__slots__ = [
|
||||
"_ir_context",
|
||||
"_ir_module",
|
||||
"_helper",
|
||||
"_target_factory",
|
||||
]
|
||||
|
||||
def __init__(self, ir_context: ir.MLIRContext = None):
|
||||
def __init__(self,
|
||||
ir_context: ir.MLIRContext = None,
|
||||
target_factory: TargetFactory = GenericTarget64):
|
||||
self._ir_context = ir.MLIRContext() if not ir_context else ir_context
|
||||
self._ir_module = self._ir_context.new_module()
|
||||
self._helper = AllDialectHelper(self._ir_context,
|
||||
ir.OpBuilder(self._ir_context))
|
||||
self._target_factory = target_factory
|
||||
|
||||
@property
|
||||
def ir_context(self):
|
||||
|
@ -66,6 +76,7 @@ class ImportFrontend:
|
|||
h = self.ir_h
|
||||
ir_c = self.ir_context
|
||||
ir_m = self.ir_module
|
||||
target = self._target_factory(h)
|
||||
filename = inspect.getsourcefile(f)
|
||||
source_lines, start_lineno = inspect.getsourcelines(f)
|
||||
source = "".join(source_lines)
|
||||
|
@ -94,7 +105,8 @@ class ImportFrontend:
|
|||
fctx = FunctionContext(ir_c=ir_c,
|
||||
ir_f=ir_f,
|
||||
ir_h=h,
|
||||
filename_ident=filename_ident)
|
||||
filename_ident=filename_ident,
|
||||
target=target)
|
||||
for f_arg, ir_arg in zip(f_params, ir_f.first_block.args):
|
||||
fctx.map_local_name(f_arg, ir_arg)
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ import sys
|
|||
from _npcomp.mlir import ir
|
||||
|
||||
from . import logging
|
||||
from .target import *
|
||||
|
||||
__all__ = [
|
||||
"FunctionContext",
|
||||
|
@ -24,14 +25,16 @@ class FunctionContext:
|
|||
"ir_c",
|
||||
"ir_f",
|
||||
"ir_h",
|
||||
"target",
|
||||
"filename_ident",
|
||||
"local_name_value_map",
|
||||
]
|
||||
|
||||
def __init__(self, ir_c, ir_f, ir_h, filename_ident):
|
||||
def __init__(self, ir_c, ir_f, ir_h, target, filename_ident):
|
||||
self.ir_c = ir_c
|
||||
self.ir_f = ir_f
|
||||
self.ir_h = ir_h
|
||||
self.target = target
|
||||
self.filename_ident = filename_ident
|
||||
self.local_name_value_map = dict()
|
||||
|
||||
|
@ -93,7 +96,8 @@ class FunctionDefImporter(BaseNodeVisitor):
|
|||
if not self._last_was_return:
|
||||
# Add a default terminator.
|
||||
none_value = ir_h.basicpy_singleton_op(ir_h.basicpy_NoneType).result
|
||||
none_cast = ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType, none_value).result
|
||||
none_cast = ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
|
||||
none_value).result
|
||||
ir_h.return_op([none_cast])
|
||||
|
||||
def visit_Assign(self, ast_node):
|
||||
|
@ -137,6 +141,8 @@ class ExpressionImporter(BaseNodeVisitor):
|
|||
def __init__(self, fctx):
|
||||
super().__init__(fctx)
|
||||
self.value = None
|
||||
self._int_type = fctx.target.impl_int_type
|
||||
self._float_type = fctx.target.impl_float_type
|
||||
|
||||
def visit(self, node):
|
||||
super().visit(node)
|
||||
|
@ -158,13 +164,11 @@ class ExpressionImporter(BaseNodeVisitor):
|
|||
elif value is None:
|
||||
self.value = ir_h.basicpy_singleton_op(ir_h.basicpy_NoneType).result
|
||||
elif isinstance(value, int):
|
||||
# TODO: Configurable type mapping
|
||||
ir_type = ir_h.i64_type
|
||||
ir_type = self._int_type
|
||||
ir_attr = ir_c.integer_attr(ir_type, value)
|
||||
self.value = ir_h.constant_op(ir_type, ir_attr).result
|
||||
elif isinstance(value, float):
|
||||
# TODO: Configurable type mapping
|
||||
ir_type = ir_h.f64_type
|
||||
ir_type = self._float_type
|
||||
ir_attr = ir_c.float_attr(ir_type, value)
|
||||
self.value = ir_h.constant_op(ir_type, ir_attr).result
|
||||
elif isinstance(value, str):
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from typing import *
|
||||
from _npcomp.mlir import ir
|
||||
|
||||
__all__ = [
|
||||
"GenericTarget32",
|
||||
"GenericTarget64",
|
||||
"Target",
|
||||
"TargetFactory",
|
||||
]
|
||||
|
||||
|
||||
class Target:
|
||||
"""
|
||||
Abstract class providing configuration and hooks for a specific compilation
|
||||
target.
|
||||
"""
|
||||
__slots__ = [
|
||||
"_mlir_helper",
|
||||
]
|
||||
|
||||
def __init__(self, mlir_helper: ir.DialectHelper):
|
||||
super().__init__()
|
||||
self._mlir_helper = mlir_helper
|
||||
|
||||
@property
|
||||
def mlir_helper(self):
|
||||
return self._mlir_helper
|
||||
|
||||
@property
|
||||
def mlir_context(self):
|
||||
return self._mlir_helper.context
|
||||
|
||||
@property
|
||||
def target_name(self) -> str:
|
||||
return NotImplementedError()
|
||||
|
||||
@property
|
||||
def impl_int_type(self) -> ir.Type:
|
||||
"""Gets the default int type for the backend for the Python 'int' type."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def impl_float_type(self) -> ir.Type:
|
||||
"""Gets the implementation's type for the python 'float' type."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class GenericTarget64(Target):
|
||||
"""A generic 64 bit target."""
|
||||
|
||||
@property
|
||||
def target_name(self) -> str:
|
||||
return "generic64"
|
||||
|
||||
@property
|
||||
def impl_int_type(self) -> ir.Type:
|
||||
"""Gets the default int type for the backend for the Python 'int' type."""
|
||||
return self.mlir_helper.i64_type
|
||||
|
||||
@property
|
||||
def impl_float_type(self) -> ir.Type:
|
||||
"""Gets the implementation's type for the python 'float' type."""
|
||||
return self.mlir_helper.f64_type
|
||||
|
||||
|
||||
class GenericTarget32(Target):
|
||||
"""A generic 32 bit target (uses 32bit ints and floats)."""
|
||||
|
||||
@property
|
||||
def target_name(self) -> str:
|
||||
return "generic32"
|
||||
|
||||
@property
|
||||
def impl_int_type(self) -> ir.Type:
|
||||
"""Gets the default int type for the backend for the Python 'int' type."""
|
||||
return self.mlir_helper.i32_type
|
||||
|
||||
@property
|
||||
def impl_float_type(self) -> ir.Type:
|
||||
"""Gets the implementation's type for the python 'float' type."""
|
||||
return self.mlir_helper.f32_type
|
||||
|
||||
|
||||
# Factory for producing a target (matches the Target constructor).
|
||||
TargetFactory = Callable[[ir.DialectHelper], Target]
|
Loading…
Reference in New Issue