Skip to content
This repository was archived by the owner on Jul 17, 2024. It is now read-only.

feat: add support for Decimal and Decimal score types #110

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ public class PythonClassTranslator {
// $ is illegal in variables/methods in Python
public static final String TYPE_FIELD_NAME = "$TYPE";
public static final String CPYTHON_TYPE_FIELD_NAME = "$CPYTHON_TYPE";
private static final String JAVA_METHOD_PREFIX = "$method$";
private static final String PYTHON_JAVA_TYPE_MAPPING_PREFIX = "$pythonJavaTypeMapping";
public static final String JAVA_METHOD_PREFIX = "$method$";
public static final String PYTHON_JAVA_TYPE_MAPPING_PREFIX = "$pythonJavaTypeMapping";

public record PreparedClassInfo(PythonLikeType type, String className, String classInternalName) {
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ai.timefold.jpyinterpreter.implementors;

import java.lang.reflect.Field;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.IdentityHashMap;
import java.util.Iterator;
Expand Down Expand Up @@ -31,6 +32,7 @@
import ai.timefold.jpyinterpreter.types.collections.PythonLikeTuple;
import ai.timefold.jpyinterpreter.types.errors.TypeError;
import ai.timefold.jpyinterpreter.types.numeric.PythonBoolean;
import ai.timefold.jpyinterpreter.types.numeric.PythonDecimal;
import ai.timefold.jpyinterpreter.types.numeric.PythonFloat;
import ai.timefold.jpyinterpreter.types.numeric.PythonInteger;
import ai.timefold.jpyinterpreter.types.numeric.PythonNumber;
Expand Down Expand Up @@ -65,76 +67,78 @@ public static PythonLikeObject wrapJavaObject(Object object, Map<Object, PythonL
return existingObject;
}

if (object instanceof OpaqueJavaReference) {
return ((OpaqueJavaReference) object).proxy();
if (object instanceof OpaqueJavaReference opaqueJavaReference) {
return opaqueJavaReference.proxy();
}

if (object instanceof PythonLikeObject) {
if (object instanceof PythonLikeObject instance) {
// Object already a PythonLikeObject; need to do nothing
return (PythonLikeObject) object;
return instance;
}

if (object instanceof Byte || object instanceof Short || object instanceof Integer || object instanceof Long) {
return PythonInteger.valueOf(((Number) object).longValue());
}

if (object instanceof BigInteger) {
return PythonInteger.valueOf((BigInteger) object);
if (object instanceof BigInteger integer) {
return PythonInteger.valueOf(integer);
}

if (object instanceof BigDecimal decimal) {
return new PythonDecimal(decimal);
}

if (object instanceof Float || object instanceof Double) {
return PythonFloat.valueOf(((Number) object).doubleValue());
}

if (object instanceof Boolean) {
return PythonBoolean.valueOf((Boolean) object);
if (object instanceof Boolean booleanValue) {
return PythonBoolean.valueOf(booleanValue);
}

if (object instanceof String) {
return PythonString.valueOf((String) object);
if (object instanceof String string) {
return PythonString.valueOf(string);
}

if (object instanceof Iterator) {
return new DelegatePythonIterator<>((Iterator) object);
if (object instanceof Iterator<?> iterator) {
return new DelegatePythonIterator<>(iterator);
}

if (object instanceof List) {
PythonLikeList out = new PythonLikeList();
if (object instanceof List<?> list) {
PythonLikeList<?> out = new PythonLikeList<>();
createdObjectMap.put(object, out);
for (Object item : (List) object) {
for (Object item : list) {
out.add(wrapJavaObject(item));
}
return out;
}

if (object instanceof Set) {
PythonLikeSet out = new PythonLikeSet();
if (object instanceof Set<?> set) {
PythonLikeSet<?> out = new PythonLikeSet<>();
createdObjectMap.put(object, out);
for (Object item : (Set) object) {
for (Object item : set) {
out.add(wrapJavaObject(item));
}
return out;
}

if (object instanceof Map) {
PythonLikeDict out = new PythonLikeDict();
if (object instanceof Map<?, ?> map) {
PythonLikeDict<?, ?> out = new PythonLikeDict<>();
createdObjectMap.put(object, out);
Set<Map.Entry<?, ?>> entrySet = ((Map) object).entrySet();
for (Map.Entry<?, ?> entry : entrySet) {
var entrySet = map.entrySet();
for (var entry : entrySet) {
out.put(wrapJavaObject(entry.getKey()), wrapJavaObject(entry.getValue()));
}
return out;
}

if (object instanceof Class) {
Class<?> maybeFunctionClass = (Class<?>) object;
if (Set.of(maybeFunctionClass.getInterfaces()).contains(PythonLikeFunction.class)) {
return new PythonCode((Class<? extends PythonLikeFunction>) maybeFunctionClass);
}
if (object instanceof Class<?> maybeFunctionClass &&
Set.of(maybeFunctionClass.getInterfaces()).contains(PythonLikeFunction.class)) {
return new PythonCode((Class<? extends PythonLikeFunction>) maybeFunctionClass);
}

if (object instanceof OpaquePythonReference) {
return new PythonObjectWrapper((OpaquePythonReference) object);
if (object instanceof OpaquePythonReference opaquePythonReference) {
return new PythonObjectWrapper(opaquePythonReference);
}

// Default: return a JavaObjectWrapper
Expand All @@ -161,6 +165,10 @@ public static PythonLikeType getPythonLikeType(Class<?> javaClass) {
return BuiltinTypes.INT_TYPE;
}

if (BigDecimal.class.equals(javaClass) || PythonDecimal.class.equals(javaClass)) {
return BuiltinTypes.DECIMAL_TYPE;
}

if (float.class.equals(javaClass) || double.class.equals(javaClass) ||
Float.class.equals(javaClass) || Double.class.equals(javaClass) ||
PythonFloat.class.equals(javaClass)) {
Expand Down Expand Up @@ -254,8 +262,7 @@ public static <T> T convertPythonObjectToJavaType(Class<? extends T> type, Pytho
return null;
}

if (object instanceof JavaObjectWrapper) {
JavaObjectWrapper wrappedObject = (JavaObjectWrapper) object;
if (object instanceof JavaObjectWrapper wrappedObject) {
Object javaObject = wrappedObject.getWrappedObject();
if (!type.isAssignableFrom(javaObject.getClass())) {
throw new TypeError("Cannot convert from (" + getPythonLikeType(javaObject.getClass()) + ") to ("
Expand All @@ -266,14 +273,13 @@ public static <T> T convertPythonObjectToJavaType(Class<? extends T> type, Pytho

if (type.equals(byte.class) || type.equals(short.class) || type.equals(int.class) || type.equals(long.class) ||
type.equals(float.class) || type.equals(double.class) || Number.class.isAssignableFrom(type)) {
if (!(object instanceof PythonNumber)) {
if (!(object instanceof PythonNumber pythonNumber)) {
throw new TypeError("Cannot convert from (" + getPythonLikeType(object.getClass()) + ") to ("
+ getPythonLikeType(type) + ").");
}
PythonNumber pythonNumber = (PythonNumber) object;
Number value = pythonNumber.getValue();

if (type.equals(BigInteger.class)) {
if (type.equals(BigInteger.class) || type.equals(BigDecimal.class)) {
return (T) value;
}

Expand Down Expand Up @@ -303,11 +309,10 @@ public static <T> T convertPythonObjectToJavaType(Class<? extends T> type, Pytho
}

if (type.equals(boolean.class) || type.equals(Boolean.class)) {
if (!(object instanceof PythonBoolean)) {
if (!(object instanceof PythonBoolean pythonBoolean)) {
throw new TypeError("Cannot convert from (" + getPythonLikeType(object.getClass()) + ") to ("
+ getPythonLikeType(type) + ").");
}
PythonBoolean pythonBoolean = (PythonBoolean) object;
return (T) (Boolean) pythonBoolean.getBooleanValue();
}

Expand Down Expand Up @@ -335,6 +340,53 @@ public static void loadName(MethodVisitor methodVisitor, String name) {
false);
}

private record ReturnValueOpDescriptor(
String wrapperClassName,
String methodName,
String methodDescriptor,
int opcode,
boolean noConversionNeeded) {
public static ReturnValueOpDescriptor noConversion() {
return new ReturnValueOpDescriptor("", "", "",
Opcodes.ARETURN, true);
}

public static ReturnValueOpDescriptor forNumeric(String methodName,
String methodDescriptor,
int opcode) {
return new ReturnValueOpDescriptor(Type.getInternalName(Number.class), methodName, methodDescriptor, opcode,
false);
}
}

private static final Map<Type, ReturnValueOpDescriptor> numericReturnValueOpDescriptorMap = Map.of(
Type.BYTE_TYPE, ReturnValueOpDescriptor.forNumeric(
"byteValue",
Type.getMethodDescriptor(Type.BYTE_TYPE),
Opcodes.IRETURN),
Type.SHORT_TYPE, ReturnValueOpDescriptor.forNumeric(
"shortValue",
Type.getMethodDescriptor(Type.SHORT_TYPE),
Opcodes.IRETURN),
Type.INT_TYPE, ReturnValueOpDescriptor.forNumeric(
"intValue",
Type.getMethodDescriptor(Type.INT_TYPE),
Opcodes.IRETURN),
Type.LONG_TYPE, ReturnValueOpDescriptor.forNumeric(
"longValue",
Type.getMethodDescriptor(Type.LONG_TYPE),
Opcodes.LRETURN),
Type.FLOAT_TYPE, ReturnValueOpDescriptor.forNumeric(
"floatValue",
Type.getMethodDescriptor(Type.FLOAT_TYPE),
Opcodes.FRETURN),
Type.DOUBLE_TYPE, ReturnValueOpDescriptor.forNumeric(
"doubleValue",
Type.getMethodDescriptor(Type.DOUBLE_TYPE),
Opcodes.DRETURN),
Type.getType(BigInteger.class), ReturnValueOpDescriptor.noConversion(),
Type.getType(BigDecimal.class), ReturnValueOpDescriptor.noConversion());

/**
* If {@code method} return type is not void, convert TOS into its Java equivalent and return it.
* If {@code method} return type is void, immediately return.
Expand All @@ -344,67 +396,36 @@ public static void loadName(MethodVisitor methodVisitor, String name) {
public static void returnValue(MethodVisitor methodVisitor, MethodDescriptor method, StackMetadata stackMetadata) {
Type returnAsmType = method.getReturnType();

if (Type.CHAR_TYPE.equals(returnAsmType)) {
throw new IllegalStateException("Unhandled case for primitive type (char).");
}

if (Type.VOID_TYPE.equals(returnAsmType)) {
methodVisitor.visitInsn(Opcodes.RETURN);
return;
}

if (Type.BYTE_TYPE.equals(returnAsmType) ||
Type.CHAR_TYPE.equals(returnAsmType) ||
Type.SHORT_TYPE.equals(returnAsmType) ||
Type.INT_TYPE.equals(returnAsmType) ||
Type.LONG_TYPE.equals(returnAsmType) ||
Type.FLOAT_TYPE.equals(returnAsmType) ||
Type.DOUBLE_TYPE.equals(returnAsmType)) {
if (numericReturnValueOpDescriptorMap.containsKey(returnAsmType)) {
var returnValueOpDescriptor = numericReturnValueOpDescriptorMap.get(returnAsmType);
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(PythonNumber.class));
methodVisitor.visitMethodInsn(Opcodes.INVOKEINTERFACE,
Type.getInternalName(PythonNumber.class),
"getValue",
Type.getMethodDescriptor(Type.getType(Number.class)),
true);
String wrapperClassName = null;
String methodName = null;
String methodDescriptor = null;
int returnOpcode = 0;

if (Type.BYTE_TYPE.equals(returnAsmType)) {
wrapperClassName = Type.getInternalName(Number.class);
methodName = "byteValue";
methodDescriptor = Type.getMethodDescriptor(Type.BYTE_TYPE);
returnOpcode = Opcodes.IRETURN;
} else if (Type.CHAR_TYPE.equals(returnAsmType)) {
throw new IllegalStateException("Unhandled case for primitive type (char).");
// returnOpcode = Opcodes.IRETURN;
} else if (Type.SHORT_TYPE.equals(returnAsmType)) {
wrapperClassName = Type.getInternalName(Number.class);
methodName = "shortValue";
methodDescriptor = Type.getMethodDescriptor(Type.SHORT_TYPE);
returnOpcode = Opcodes.IRETURN;
} else if (Type.INT_TYPE.equals(returnAsmType)) {
wrapperClassName = Type.getInternalName(Number.class);
methodName = "intValue";
methodDescriptor = Type.getMethodDescriptor(Type.INT_TYPE);
returnOpcode = Opcodes.IRETURN;
} else if (Type.FLOAT_TYPE.equals(returnAsmType)) {
wrapperClassName = Type.getInternalName(Number.class);
methodName = "floatValue";
methodDescriptor = Type.getMethodDescriptor(Type.FLOAT_TYPE);
returnOpcode = Opcodes.FRETURN;
} else if (Type.LONG_TYPE.equals(returnAsmType)) {
wrapperClassName = Type.getInternalName(Number.class);
methodName = "longValue";
methodDescriptor = Type.getMethodDescriptor(Type.LONG_TYPE);
returnOpcode = Opcodes.LRETURN;
} else if (Type.DOUBLE_TYPE.equals(returnAsmType)) {
wrapperClassName = Type.getInternalName(Number.class);
methodName = "doubleValue";
methodDescriptor = Type.getMethodDescriptor(Type.DOUBLE_TYPE);
returnOpcode = Opcodes.DRETURN;

if (returnValueOpDescriptor.noConversionNeeded) {
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, returnAsmType.getInternalName());
methodVisitor.visitInsn(Opcodes.ARETURN);
return;
}

methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL,
wrapperClassName, methodName, methodDescriptor,
returnValueOpDescriptor.wrapperClassName,
returnValueOpDescriptor.methodName,
returnValueOpDescriptor.methodDescriptor,
false);
methodVisitor.visitInsn(returnOpcode);
methodVisitor.visitInsn(returnValueOpDescriptor.opcode);
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import ai.timefold.jpyinterpreter.types.collections.view.DictValueView;
import ai.timefold.jpyinterpreter.types.numeric.PythonBoolean;
import ai.timefold.jpyinterpreter.types.numeric.PythonComplex;
import ai.timefold.jpyinterpreter.types.numeric.PythonDecimal;
import ai.timefold.jpyinterpreter.types.numeric.PythonFloat;
import ai.timefold.jpyinterpreter.types.numeric.PythonInteger;
import ai.timefold.jpyinterpreter.types.numeric.PythonNumber;
Expand Down Expand Up @@ -60,6 +61,7 @@ public class BuiltinTypes {
public static final PythonLikeType BOOLEAN_TYPE = new PythonLikeType("bool", PythonBoolean.class, List.of(INT_TYPE));
public static final PythonLikeType FLOAT_TYPE = new PythonLikeType("float", PythonFloat.class, List.of(NUMBER_TYPE));
public final static PythonLikeType COMPLEX_TYPE = new PythonLikeType("complex", PythonComplex.class, List.of(NUMBER_TYPE));
public final static PythonLikeType DECIMAL_TYPE = new PythonLikeType("Decimal", PythonDecimal.class, List.of(NUMBER_TYPE));

public static final PythonLikeType STRING_TYPE = new PythonLikeType("str", PythonString.class, List.of(BASE_TYPE));
public static final PythonLikeType BYTES_TYPE = new PythonLikeType("bytes", PythonBytes.class, List.of(BASE_TYPE));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,14 @@ public PythonLikeTuple<T> createNewInstance() {
return new PythonLikeTuple<>();
}

public static PythonLikeTuple fromItems(PythonLikeObject... items) {
PythonLikeTuple result = new PythonLikeTuple();
public static <T extends PythonLikeObject> PythonLikeTuple<T> fromItems(T... items) {
PythonLikeTuple<T> result = new PythonLikeTuple<>();
Collections.addAll(result, items);
return result;
}

public static PythonLikeTuple fromList(List<PythonLikeObject> other) {
PythonLikeTuple result = new PythonLikeTuple();
public static <T extends PythonLikeObject> PythonLikeTuple<T> fromList(List<T> other) {
PythonLikeTuple<T> result = new PythonLikeTuple<>();
result.addAll(other);
return result;
}
Expand Down
Loading
Loading