import React, { useState } from 'react';
import { withAuthenticator } from '@aws-amplify/ui-react';
import { Amplify } from 'aws-amplify';
import { fetchAuthSession } from 'aws-amplify/auth';
import awsconfig from './aws-exports';
import '@aws-amplify/ui-react/styles.css';
import './SubmitTFTask.css';
import { v4 as uuidv4 } from 'uuid';
import { createS3Client, getIdentityPoolUserId } from './chunkHelper'; // Removed sanitizeIdentityPoolUserId import
import { S3Client, PutObjectCommand } from "@aws-sdk/client-s3"; 

// Runs once when this module is evaluated: hands Amplify the settings object
// imported from ./aws-exports.
Amplify.configure(awsconfig);

function SubmitTFTask() {
  // Selected files; each stays null until the user picks one in its file input.
  const [xsFile, setXsFile] = useState(null); // Input data (xs.csv)
  const [ysFile, setYsFile] = useState(null); // Output data (ys.csv)
  const [modelFile, setModelFile] = useState(null); // TF model (model.json)
  const [weightsFile, setWeightsFile] = useState(null); // Weights (weights.bin)
  // Shapes are kept as the raw text the user typed, e.g. "[10000, 10, 1]".
  const [inputShape, setInputShape] = useState(''); // Input shape text
  const [outputShape, setOutputShape] = useState(''); // Output shape text
  // Progress/error log shown in the message box. It is rendered as HTML, not plain text.
  const [message, setMessage] = useState('');
  // Batch ID generated by the most recent upload attempt (null before the first one).
  const [batchId, setBatchId] = useState(null);
  // Training settings keyed by field name (epochs, loss, optimizer, ...); starts empty.
  const [taskMetadata, setTaskMetadata] = useState({});
  // True while an upload is running; used to disable the Upload button.
  const [isUploading, setIsUploading] = useState(false);

  /**
   * Change handler for the xs.csv file input.
   * Keeps only the first entry of the input's file list.
   */
  const handleXsFileChange = (event) => {
    const { files } = event.target;
    setXsFile(files[0]);
  };

  /**
   * Change handler for the ys.csv file input.
   * Keeps only the first entry of the input's file list.
   */
  const handleYsFileChange = (event) => {
    const { files } = event.target;
    setYsFile(files[0]);
  };

  /**
   * Change handler for the model.json file input.
   * Keeps only the first entry of the input's file list.
   */
  const handleModelFileChange = (event) => {
    const { files } = event.target;
    setModelFile(files[0]);
  };

  /**
   * Change handler for the weights.bin file input.
   * Keeps only the first entry of the input's file list.
   */
  const handleWeightsFileChange = (event) => {
    const { files } = event.target;
    setWeightsFile(files[0]);
  };

  /**
   * Writes a single entry of the task metadata, leaving every other entry as is.
   * Uses the functional form of the state setter so the update is applied on top
   * of the latest stored metadata.
   *
   * @param {string} field - Metadata key to set.
   * @param {*} value - Value to store under that key.
   */
  const handleTaskMetadataChange = (field, value) => {
    const withUpdatedField = (current) => {
      return { ...current, [field]: value };
    };
    setTaskMetadata(withUpdatedField);
  };

  /**
   * Uploads one file to the data-chunks bucket and resolves with its object key.
   *
   * Key layout:
   *   'model' | 'weights' -> chunks/{owner}/{newBatchId}/model/{file.name}
   *   'xs' | 'ys'         -> chunks/{owner}/{newBatchId}/{chunkId}/{type}/{file.name}
   *   anything else       -> chunks/{owner}/{newBatchId}/{file.name}
   *
   * The shapes and training settings currently held in component state are
   * attached to the object as user metadata.
   *
   * @param {object} s3Client - Client whose `send(command)` performs the upload.
   * @param {File} file - File to upload; its `name` and `type` are read.
   * @param {string} newBatchId - Batch the file belongs to.
   * @param {string} owner - Owner segment of the key, also stored as metadata.
   * @param {string} type - 'model', 'weights', 'xs', 'ys', or any other label.
   * @param {?string} [chunkId=null] - Chunk the file belongs to; required for 'xs'/'ys'.
   * @returns {Promise<string>} The key the object was stored under.
   * @throws {Error} If `type` is 'xs' or 'ys' and no `chunkId` is given, or if the upload fails.
   */
  const uploadFileToS3 = async (s3Client, file, newBatchId, owner, type, chunkId = null) => {
    const hasChunkId = chunkId != null;
    const batchPrefix = `chunks/${owner}/${newBatchId}`;

    let fileKey;
    if (type === 'model' || type === 'weights') {
      fileKey = `${batchPrefix}/model/${file.name}`;
    } else if (type === 'xs' || type === 'ys') {
      // Without this guard the key would contain a literal "null" path segment.
      if (!hasChunkId) {
        throw new Error(`chunkId is required for "${type}" uploads`);
      }
      fileKey = `${batchPrefix}/${chunkId}/${type}/${file.name}`;
    } else {
      fileKey = `${batchPrefix}/${file.name}`;
    }

    // Metadata values are sent as strings, so coerce each one explicitly.
    const metadata = {
      'task-type': 'tensorflow',
      'batch-id': String(newBatchId),
      'owner': String(owner),
      'file-type': String(type),
      'input-shape': String(inputShape),
      'output-shape': String(outputShape),
      'epochs': String(taskMetadata.epochs || '10'),
      'loss': String(taskMetadata.loss || 'meanSquaredError'),
      'optimizer': String(taskMetadata.optimizer || 'adam'),
    };
    // Batch-level files (model, weights) have no chunk; leave the entry out
    // rather than sending a null metadata value.
    if (hasChunkId) {
      metadata['chunk-id'] = String(chunkId);
    }

    const putObjectParams = {
      Bucket: "my-data-chunks-bucket",
      Key: fileKey,
      Body: file,
      // file.type is '' when the browser cannot determine a type (e.g. .bin files).
      ContentType: file.type || 'application/octet-stream',
      Metadata: metadata,
    };

    await s3Client.send(new PutObjectCommand(putObjectParams));
    console.log(`${type} file uploaded to: ${fileKey}`);
    return fileKey;
  };

  /**
   * Registers a batch with the backend by POSTing to the `/start` endpoint.
   *
   * The request is authorised with the current session's ID token, sent as a
   * Bearer token.
   *
   * @param {string} batchId - Batch being registered.
   * @param {string} fileName - File name reported for the batch.
   * @param {string} taskType - Task type label (the caller passes 'tensorflow').
   * @param {string} taskMetadata - Task settings; the caller passes a JSON string.
   * @param {string} owner - Owner identifier for the batch.
   * @param {string} modelS3Location - Object key of the uploaded model file.
   * @returns {Promise<*>} The parsed JSON body of the response.
   * @throws {Error} If the session has no ID token, or the response status is not OK
   *   (the message then carries the status code and the response body text).
   */
  const createRiddleSourceEntry = async (batchId, fileName, taskType, taskMetadata, owner, modelS3Location) => {
    const endpoint = 'https://0pawfsvt1a.execute-api.us-east-1.amazonaws.com/dev/start';

    const session = await fetchAuthSession();
    // A session can come back without tokens; fail with a clear message instead
    // of a TypeError from dereferencing undefined.
    const idToken = session.tokens && session.tokens.idToken;
    if (!idToken) {
      throw new Error('No ID token available; please sign in again.');
    }

    const response = await fetch(endpoint, {
      method: 'POST',
      headers: {
        Authorization: `Bearer ${idToken.toString()}`,
        'Content-Type': 'application/json',
      },
      body: JSON.stringify({
        batchId,
        fileName,
        taskType,
        taskMetadata,
        owner,
        modelS3Location,
      }),
    });

    if (!response.ok) {
      const errorText = await response.text();
      throw new Error(`HTTP error! status: ${response.status}, body: ${errorText}`);
    }

    console.log('Riddle Source entry created:', { batchId, taskType, taskMetadata, modelS3Location });
    return await response.json();
  };

  /**
   * Click handler for the Upload button. Runs the whole submission in order:
   *
   *   1. Validate that all four files and both shapes are present.
   *   2. Resolve the owner from the current session and generate a batch ID.
   *   3. Upload the model and weights files at batch level.
   *   4. Create the backend entry for the batch.
   *   5. Generate a chunk ID and upload xs, then ys, under it.
   *
   * Progress is appended to `message` line by line. On failure the log is
   * replaced with the error text. `isUploading` is always reset at the end.
   *
   * @returns {Promise<void>}
   */
  const handleFileUpload = async () => {
    // The message box renders `message` as HTML, so file names and error text
    // (which can include a server response body) are escaped before they are
    // interpolated. Only the fixed dashboard link below is real markup.
    const escapeHtml = (text) => String(text)
      .replace(/&/g, '&amp;')
      .replace(/</g, '&lt;')
      .replace(/>/g, '&gt;')
      .replace(/"/g, '&quot;')
      .replace(/'/g, '&#39;');
    // Appends one line to the progress log.
    const appendMessage = (line) => setMessage((prevMessage) => `${prevMessage}\n${line}`);
    // File size in KB with two decimals, as shown in the log.
    const sizeInKb = (file) => (file.size / 1024).toFixed(2);

    if (!xsFile || !ysFile || !modelFile || !weightsFile || !inputShape || !outputShape) {
      setMessage('Please select all required files and provide input/output shapes.');
      return;
    }

    setIsUploading(true);
    setMessage('Starting file upload...');

    try {
      const session = await fetchAuthSession();
      // A session can come back without tokens; surface that as a readable error.
      const sessionIdToken = session.tokens && session.tokens.idToken;
      if (!sessionIdToken) {
        throw new Error('No ID token available; please sign in again.');
      }
      const idToken = sessionIdToken.toString();
      // The full Identity Pool user ID is used as the owner.
      const owner = await getIdentityPoolUserId(idToken);
      const newBatchId = uuidv4();
      appendMessage(`Batch ID: ${newBatchId}`);
      setBatchId(newBatchId);

      const s3Client = createS3Client(idToken);

      // Model and weights are stored once per batch.
      appendMessage(`Uploading model file (${escapeHtml(modelFile.name)}, ${sizeInKb(modelFile)} KB)...`);
      const modelS3Location = await uploadFileToS3(s3Client, modelFile, newBatchId, owner, 'model');
      appendMessage('Model file uploaded.');
      appendMessage(`Uploading weights file (${escapeHtml(weightsFile.name)}, ${sizeInKb(weightsFile)} KB)...`);
      await uploadFileToS3(s3Client, weightsFile, newBatchId, owner, 'weights');
      appendMessage('Weights file uploaded.');

      // Mirror the shapes into the task metadata state.
      handleTaskMetadataChange('inputShape', inputShape);
      handleTaskMetadataChange('outputShape', outputShape);

      // Batch-level settings are sent to the backend as a JSON string.
      const taskMetadataString = JSON.stringify({
        modelS3Location,
        inputShape,
        outputShape,
        numColumns: taskMetadata.numColumns || 'undefined',
        modelType: taskMetadata.modelType || 'neural-network',
        epochs: taskMetadata.epochs || 10,
        loss: taskMetadata.loss || 'meanSquaredError',
        optimizer: taskMetadata.optimizer || 'adam',
      });

      // The backend entry is created before any data chunk is uploaded.
      await createRiddleSourceEntry(newBatchId, xsFile.name, 'tensorflow', taskMetadataString, owner, modelS3Location);

      // xs and ys share one chunk ID.
      const chunkId = uuidv4();
      appendMessage(`Chunk ID: ${chunkId}`);

      // The two data files are uploaded one after the other, xs first.
      console.log('Uploading xs file...');
      appendMessage(`Uploading input data file (xs): ${escapeHtml(xsFile.name)} (${sizeInKb(xsFile)} KB)...`);
      const xsChunkKey = await uploadFileToS3(s3Client, xsFile, newBatchId, owner, 'xs', chunkId);
      appendMessage('Input data file uploaded.');
      console.log('XS file uploaded with key:', xsChunkKey);

      console.log('Uploading ys file...');
      appendMessage(`Uploading output data file (ys): ${escapeHtml(ysFile.name)} (${sizeInKb(ysFile)} KB)...`);
      const ysChunkKey = await uploadFileToS3(s3Client, ysFile, newBatchId, owner, 'ys', chunkId);
      appendMessage('Output data file uploaded.');
      console.log('YS file uploaded with key:', ysChunkKey);

      // Final summary; the anchor is trusted, fixed markup.
      appendMessage('\nAll files have been successfully uploaded!\n\nCheck status in <a href="/dashboard" class="dashboard-link">Dashboard</a>');
    } catch (error) {
      console.error('Error uploading files:', error);
      setMessage(`Error uploading files: ${escapeHtml(error.message)}`);
    } finally {
      setIsUploading(false);
    }
  };

  return (
    <div className="submit-task-page">
      <div className="submit-task-heading">
        <h2>Riddle: Submit TF Task</h2>
      </div>

      {message && (
        <div className="message-box">
          <pre dangerouslySetInnerHTML={{ __html: message }}></pre>
        </div>
      )}

      <div className="submit-task-form">
        <div className="form-group row">
          <label htmlFor="xs-file-upload" className="form-label">Input Data (xs.csv)</label>
          <input type="file" id="xs-file-upload" accept=".csv" onChange={handleXsFileChange} required className="form-input" />
        </div>

        <div className="form-group row">
          <label htmlFor="ys-file-upload" className="form-label">Output Data (ys.csv)</label>
          <input type="file" id="ys-file-upload" accept=".csv" onChange={handleYsFileChange} required className="form-input" />
        </div>

        <div className="form-group row">
          <label htmlFor="model-upload" className="form-label">TF Model (model.json)</label>
          <input type="file" id="model-upload" accept=".json" onChange={handleModelFileChange} required className="form-input" />
        </div>

        <div className="form-group row">
          <label htmlFor="weights-upload" className="form-label">Weights (weights.bin)</label>
          <input type="file" id="weights-upload" accept=".bin" onChange={handleWeightsFileChange} required className="form-input" />
        </div>

        <div className="form-group row">
          <label htmlFor="input-shape" className="form-label">Input Shape</label>
          <input type="text" id="input-shape" value={inputShape} onChange={(e) => setInputShape(e.target.value)} required placeholder="e.g., [10000, 10, 1]" className="form-input" />
        </div>
        <div className="form-group row">
          <label htmlFor="output-shape" className="form-label">Output Shape</label>
          <input type="text" id="output-shape" value={outputShape} onChange={(e) => setOutputShape(e.target.value)} required placeholder="e.g., [10000, 10, 1]" className="form-input" />
        </div>
        <div className="form-group row">
          <label htmlFor="epochs" className="form-label">Epochs</label>
          <input type="number" id="epochs" value={taskMetadata.epochs || 10} onChange={(e) => handleTaskMetadataChange('epochs', e.target.value)} required className="form-input" />
        </div>
        <div className="form-group row">
          <label htmlFor="loss" className="form-label">Loss Function</label>
          <input type="text" id="loss" value={taskMetadata.loss || 'meanSquaredError'} onChange={(e) => handleTaskMetadataChange('loss', e.target.value)} required className="form-input" />
        </div>
        <div className="form-group row">
          <label htmlFor="optimizer" className="form-label">Optimizer</label>
          <input type="text" id="optimizer" value={taskMetadata.optimizer || 'adam'} onChange={(e) => handleTaskMetadataChange('optimizer', e.target.value)} required className="form-input" />
        </div>

        <div className="form-group row">
          <button onClick={handleFileUpload} disabled={isUploading} className="form-button">
            {isUploading ? 'Uploading...' : 'Upload'}
          </button>
        </div>
      </div>
    </div>
  );
}

// The module's default export is the component wrapped in the withAuthenticator
// higher-order component from @aws-amplify/ui-react.
export default withAuthenticator(SubmitTFTask);
